mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
24 Commits
log/conn-d
...
v0.64.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
5333e55a81 | ||
|
|
81c11df103 | ||
|
|
f74bc48d16 | ||
|
|
0169e4540f | ||
|
|
cead3f38ee | ||
|
|
b55262d4a2 | ||
|
|
2248ff392f | ||
|
|
06966da012 | ||
|
|
d4f7df271a | ||
|
|
5299549eb6 | ||
|
|
7d791620a6 | ||
|
|
44ab454a13 | ||
|
|
11f50d6c38 | ||
|
|
05af39a69b | ||
|
|
074df56c3d | ||
|
|
2381e216e4 | ||
|
|
ded04b7627 |
@@ -60,8 +60,8 @@
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
|
||||
@@ -3,15 +3,7 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
var needsLogin bool
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.OnLoginSuccess()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
err, isAuthError := authClient.Login(ctx, "", "")
|
||||
if isAuthError {
|
||||
needsLogin = true
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("login check failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -68,6 +69,8 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -136,6 +139,7 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -176,7 +180,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth client: %w", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||
}
|
||||
|
||||
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("create chain: %w", err)
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -168,6 +168,10 @@ type Manager interface {
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -48,8 +49,10 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.Flush()
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||
return fmt.Errorf("notrack chains not initialized")
|
||||
}
|
||||
|
||||
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||
loopback := []byte{127, 0, 0, 1}
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush notrack rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawPrerouting,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush chain creation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) refreshNoTrackChains() error {
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, c := range chains {
|
||||
if c.Table.Name != tableName {
|
||||
continue
|
||||
}
|
||||
switch c.Name {
|
||||
case chainNameRawOutput:
|
||||
m.notrackOutputChain = c
|
||||
case chainNameRawPrerouting:
|
||||
m.notrackPreroutingChain = c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -570,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
|
||||
@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
host, portStr, err := net.SplitHostPort(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
|
||||
@@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error {
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||
func (w *WGIface) GetProxyPort() uint16 {
|
||||
return w.wgProxyFactory.GetProxyPort()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -114,34 +114,21 @@ func (p *ProxyBind) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start package redirection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert endpoint address: %v", err)
|
||||
} else {
|
||||
p.wgCurrentUsed = ep
|
||||
}
|
||||
p.wgCurrentUsed = ep
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, errors.New("nil address")
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
@@ -225,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid address")
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -26,13 +24,10 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
proxyPort int
|
||||
mtu uint16
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
rawConnIPv4 net.PacketConn
|
||||
rawConnIPv6 net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
proxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.proxyPort = proxyPort
|
||||
|
||||
// Prepare IPv4 raw socket (required)
|
||||
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
// Prepare IPv6 raw socket (optional)
|
||||
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||
}
|
||||
if p.rawConnIPv6 != nil {
|
||||
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
Port: proxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
if p.rawConnIPv4 != nil {
|
||||
if err := p.rawConnIPv4.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.rawConnIPv6 != nil {
|
||||
if err := p.rawConnIPv6.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy listening port.
|
||||
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||
return uint16(p.proxyPort)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -216,34 +241,3 @@ generatePort:
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,12 +10,89 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
var (
|
||||
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||
|
||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||
localHostNetIPv6 = net.ParseIP("::1")
|
||||
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
)
|
||||
|
||||
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||
type PacketHeaders struct {
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH *layers.UDP
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr net.IP
|
||||
isIPv4 bool
|
||||
}
|
||||
|
||||
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
var localHostAddr net.IP
|
||||
var isIPv4 bool
|
||||
|
||||
// Check if source address is IPv4 or IPv6
|
||||
if endpoint.IP.To4() != nil {
|
||||
// IPv4 path
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPv4,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
localHostAddr = localHostNetIPv4
|
||||
isIPv4 = true
|
||||
} else {
|
||||
// IPv6 path
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPv6,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
localHostAddr = localHostNetIPv6
|
||||
isIPv4 = false
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpoint.Port),
|
||||
DstPort: layers.UDPPort(localWGListenPort),
|
||||
}
|
||||
|
||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return &PacketHeaders{
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
isIPv4: isIPv4,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
headers *PacketHeaders
|
||||
headerCurrentUsed *PacketHeaders
|
||||
rawConn net.PacketConn
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
|
||||
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create packet sender: %w", err)
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
return errIPv6ConnNotAvailable
|
||||
}
|
||||
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
return errIPv4ConnNotAvailable
|
||||
}
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
return err
|
||||
p.headers = headers
|
||||
p.rawConn = p.selectRawConn(headers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
p.headerCurrentUsed = p.headers
|
||||
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
@@ -91,12 +188,32 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||
return
|
||||
}
|
||||
|
||||
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create packet headers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
log.Error(errIPv6ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
log.Error(errIPv4ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
if endpoint != nil && endpoint.IP != nil {
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
}
|
||||
p.headerCurrentUsed = header
|
||||
p.rawConn = p.selectRawConn(header)
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
@@ -138,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -164,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||
defer func() {
|
||||
if err := header.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
|
||||
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||
if header.isIPv4 {
|
||||
return p.wgeBPFProxy.rawConnIPv4
|
||||
}
|
||||
return p.wgeBPFProxy.rawConnIPv6
|
||||
}
|
||||
|
||||
@@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||
if w.ebpfProxy == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ebpfProxy.GetProxyPort()
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
|
||||
@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||
func (w *USPFactory) GetProxyPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,43 +8,87 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||
}
|
||||
|
||||
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||
if isIPv4 {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
353
client/iface/wgproxy/redirect_test.go
Normal file
353
client/iface/wgproxy/redirect_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
return addr1.String() == addr2.String()
|
||||
}
|
||||
|
||||
// Compare IP and Port, ignoring zone
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||
wgPort := 51853
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||
// It verifies that:
|
||||
// 1. Initial traffic from relay connection works
|
||||
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||
// 3. Multiple packets are correctly redirected with the new source address
|
||||
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener4.Close()
|
||||
|
||||
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener6.Close()
|
||||
|
||||
// Determine which listener to use based on the NetBird address IP version
|
||||
// (this is where initial traffic will come from before RedirectAs is called)
|
||||
var wgListener *net.UDPConn
|
||||
if p2pEndpoint.IP.To4() == nil {
|
||||
wgListener = wgListener6
|
||||
} else {
|
||||
wgListener = wgListener4
|
||||
}
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0, // Random port
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
// Add TURN connection to proxy
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the proxy
|
||||
proxy.Work()
|
||||
|
||||
// Phase 1: Test initial relay traffic
|
||||
msgFromRelay := []byte("hello from relay")
|
||||
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write to relay server: %v", err)
|
||||
}
|
||||
|
||||
// Set read deadline to avoid hanging
|
||||
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _, err := wgListener4.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msgFromRelay) {
|
||||
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||
}
|
||||
|
||||
// Phase 2: Redirect to p2p endpoint
|
||||
proxy.RedirectAs(p2pEndpoint)
|
||||
|
||||
// Give the proxy a moment to process the redirect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Phase 3: Test redirected traffic
|
||||
redirectedMessages := [][]byte{
|
||||
[]byte("redirected message 1"),
|
||||
[]byte("redirected message 2"),
|
||||
[]byte("redirected message 3"),
|
||||
}
|
||||
|
||||
for i, msg := range redirectedMessages {
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
// Verify message content
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||
}
|
||||
|
||||
// Verify source address matches p2p endpoint (this is the key test)
|
||||
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||
t.Errorf("message %d: expected source address %s, got %s",
|
||||
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listener
|
||||
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener.Close()
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
// Test switching between multiple endpoints - using addresses in local subnet
|
||||
endpoints := []*net.UDPAddr{
|
||||
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||
}
|
||||
|
||||
for i, endpoint := range endpoints {
|
||||
proxy.RedirectAs(endpoint)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg := []byte("test message")
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||
}
|
||||
|
||||
if !compareUDPAddr(srcAddr, endpoint) {
|
||||
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||
i, endpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
|
||||
@@ -19,37 +19,56 @@ var (
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
localHostNetIPAddrV4 = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
localHostNetIPAddrV6 = &net.IPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr *net.IPAddr
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
// Create only the raw socket for the address family we need
|
||||
var rawSocket net.PacketConn
|
||||
var err error
|
||||
var localHostAddr *net.IPAddr
|
||||
|
||||
if srcAddr.IP.To4() != nil {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
localHostAddr = localHostNetIPAddrV4
|
||||
} else {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
localHostAddr = localHostNetIPAddrV6
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
}
|
||||
|
||||
return f, nil
|
||||
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
|
||||
// Check if source IP is IPv4 or IPv6
|
||||
if srcAddr.IP.To4() != nil {
|
||||
// IPv4
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPAddrV4.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
} else {
|
||||
// IPv6
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPAddrV6.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -26,12 +25,56 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return pkceFlowInfo, nil
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
return deviceFlowInfo, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -35,17 +34,67 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD exact match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD subdomain match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD wildcard match",
|
||||
handlerDomain: "*.example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "two letter domain labels",
|
||||
handlerDomain: "a.b.",
|
||||
queryDomain: "a.b.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain with subdomain match",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "sub.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -38,6 +40,9 @@ const (
|
||||
type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
|
||||
mu sync.RWMutex
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager() (*systemConfigurator, error) {
|
||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
var serverAddresses []netip.Addr
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip.Unmap()
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||
ip = ip.Unmap()
|
||||
serverAddresses = append(serverAddresses, ip)
|
||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = DefaultPort
|
||||
|
||||
s.mu.Lock()
|
||||
s.origNameservers = serverAddresses
|
||||
s.mu.Unlock()
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
|
||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func TestGetOriginalNameservers(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
origNameservers: []netip.Addr{
|
||||
netip.MustParseAddr("8.8.8.8"),
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
},
|
||||
}
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
assert.Len(t, servers, 2)
|
||||
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||
}
|
||||
|
||||
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
|
||||
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||
|
||||
for _, server := range servers {
|
||||
assert.True(t, server.IsValid(), "server address should be valid")
|
||||
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||
}
|
||||
|
||||
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||
}
|
||||
|
||||
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
}
|
||||
|
||||
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
routeAll bool
|
||||
}{
|
||||
{"routeall_false", false},
|
||||
{"routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
for _, srv := range initialServers {
|
||||
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.routeAll,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialRoute bool
|
||||
}{
|
||||
{"start_with_routeall_false", false},
|
||||
{"start_with_routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.initialRoute,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
// First apply
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle RouteAll
|
||||
config.RouteAll = !tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle back
|
||||
config.RouteAll = tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
for _, srv := range servers {
|
||||
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
s.registerFallback(config)
|
||||
}
|
||||
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
||||
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||
if !ok {
|
||||
@@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||
if len(originalNameservers) == 0 {
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,15 +8,21 @@ import (
|
||||
|
||||
type MockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
lastResponse *dns.Msg
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
rw.lastResponse = m
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||
return rw.lastResponse
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
|
||||
@@ -505,6 +505,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set up notrack rules immediately after proxy is listening to prevent
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -569,9 +573,11 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("create firewall manager: %w", err)
|
||||
}
|
||||
if e.firewall == nil {
|
||||
return fmt.Errorf("create firewall manager: received nil manager")
|
||||
}
|
||||
|
||||
if err := e.initFirewall(); err != nil {
|
||||
@@ -617,6 +623,23 @@ func (e *Engine) initFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||
func (e *Engine) setupWGProxyNoTrack() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort := e.wgInterface.GetProxyPort()
|
||||
if proxyPort == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
@@ -1644,6 +1667,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
|
||||
@@ -107,6 +107,7 @@ type MockWGIface struct {
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetProxyPortFunc func() uint16
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
@@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||
if m.GetProxyPortFunc != nil {
|
||||
return m.GetProxyPortFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
||||
mgmURL := config.ManagementURL
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
|
||||
|
||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
if netstack.IsEnabled() {
|
||||
log.Debugf("Network monitor: skipping in netstack mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
nw.mu.Lock()
|
||||
if nw.cancel != nil {
|
||||
nw.mu.Unlock()
|
||||
|
||||
@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
||||
return false, errors.New("not supported on mobile platforms")
|
||||
}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||
return false, errors.New("empty interface name")
|
||||
|
||||
@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||
return true // Assume login is required if we can't create auth client
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||
// If the check fails, assume login is required to be safe
|
||||
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
|
||||
|
||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||
go func() {
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||
return
|
||||
}
|
||||
jwtToken := tokenInfo.GetTokenToUse()
|
||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||
return
|
||||
}
|
||||
defer authClient.Close()
|
||||
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,13 +7,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
// LoginSync performs a synchronous login check without UI interaction
|
||||
// Used for background VPN connection where user should already be authenticated
|
||||
func (a *Auth) LoginSync() error {
|
||||
var needsLogin bool
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
|
||||
return fmt.Errorf("not authenticated")
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
}
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||
var needsLogin bool
|
||||
|
||||
// Create context with device name if provided
|
||||
ctx := a.ctx
|
||||
if deviceName != "" {
|
||||
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
}
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
||||
return
|
||||
})
|
||||
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(ctx, func() error {
|
||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
}
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
|
||||
const authInfoRequestTimeout = 30 * time.Second
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
// GetConfigJSON returns the current config as a JSON string.
|
||||
// This can be used by the caller to persist the config via alternative storage
|
||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
|
||||
@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
var status internal.StatusType
|
||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
if err != nil {
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
log.Errorf("failed to create auth client: %v", err)
|
||||
return internal.StatusLoginFailed, err
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
var status internal.StatusType
|
||||
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
log.Warnf("failed login: %v", err)
|
||||
status = internal.StatusNeedsLogin
|
||||
} else {
|
||||
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.oauthAuthFlow.waitCancel()
|
||||
}
|
||||
|
||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
waitCTX, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
s.mutex.Lock()
|
||||
|
||||
@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||
// Create a backend session to mirror the client's session request.
|
||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
<-session.Context().Done()
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
if err := session.Exit(0); err != nil {
|
||||
log.Debugf("session exit: %v", err)
|
||||
if err := serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- serverSession.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
log.Debugf("shell session: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
// handleExecution executes an SSH command or shell with privilege validation
|
||||
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,17 +15,17 @@ import (
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||
return isUtilLinux
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
// createSuCommand creates a command using su - for privilege switching.
|
||||
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
}
|
||||
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
args := []string{"-"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
args = append(args, localUser.Username)
|
||||
|
||||
command := session.RawCommand()
|
||||
if command != "" {
|
||||
args = append(args, "-c", command)
|
||||
}
|
||||
|
||||
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
return []string{shell}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
return []string{shell, "-c", cmdString}
|
||||
}
|
||||
|
||||
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
} else {
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
// Close PTY to unblock io.Copy goroutines
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close after completion: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,32 +20,32 @@ import (
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(logger, username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
logger.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
logger.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Command: session.RawCommand(),
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
|
||||
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -23,25 +26,67 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||
// propagate exit codes.
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
runTestExecutor()
|
||||
return
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// runTestExecutor emulates the netbird executor for tests.
|
||||
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||
func runTestExecutor() {
|
||||
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||
|
||||
shell := "/bin/sh"
|
||||
var command string
|
||||
for i := 3; i < len(os.Args); i++ {
|
||||
switch os.Args[i] {
|
||||
case "--shell":
|
||||
if i+1 < len(os.Args) {
|
||||
shell = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--cmd":
|
||||
if i+1 < len(os.Args) {
|
||||
command = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if command == "" {
|
||||
cmd = exec.Command(shell)
|
||||
} else {
|
||||
cmd = exec.Command(shell, "-c", command)
|
||||
}
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||
// This ensures our implementation matches OpenSSH behavior for:
|
||||
// - ssh host command (no PTY - default when no TTY)
|
||||
// - ssh -T host command (explicit no PTY)
|
||||
// - ssh -t host command (force PTY)
|
||||
// - ssh -T host (no PTY shell - our implementation)
|
||||
func TestSSHPtyModes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
baseArgs := []string{
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "BatchMode=yes",
|
||||
}
|
||||
|
||||
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "no_pty_default")
|
||||
})
|
||||
|
||||
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "explicit_no_pty")
|
||||
})
|
||||
|
||||
t.Run("command_force_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "force_pty")
|
||||
})
|
||||
|
||||
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||
assert.NoError(t, err, "write echo command")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
assert.NoError(t, err, "write exit command")
|
||||
}()
|
||||
|
||||
output, _ := io.ReadAll(stdout)
|
||||
err = cmd.Wait()
|
||||
|
||||
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "Command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "PTY command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||
})
|
||||
|
||||
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||
})
|
||||
|
||||
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||
assert.Contains(t, string(output), "stdout_msg")
|
||||
assert.Contains(t, string(output), "stderr_msg")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct{}
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
os.Exit(ExitCodeSuccess)
|
||||
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||
} else {
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
}
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if config.Command == "" {
|
||||
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||
} else {
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
}
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
|
||||
@@ -28,22 +28,45 @@ const (
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct{}
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -56,7 +79,6 @@ const (
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
return prepareDomainS4ULogon(logger, username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(username)
|
||||
return prepareLocalS4ULogon(logger, username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
logger.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Debugf("close impersonation token: %v", err)
|
||||
pd.log().Debugf("close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
||||
func (s *Server) isPortForwardingEnabled() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
|
||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
sessions = append(sessions, info)
|
||||
}
|
||||
|
||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
||||
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||
for key, connState := range s.connections {
|
||||
remoteAddr := string(key)
|
||||
if reportedAddrs[remoteAddr] {
|
||||
|
||||
@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
||||
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
expectAllowed bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "session_allowed_with_local_forwarding",
|
||||
name: "shell_with_local_forwarding_enabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "session_allowed_with_remote_forwarding",
|
||||
name: "shell_with_remote_forwarding_enabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "session_allowed_with_both",
|
||||
name: "shell_with_both_forwarding_enabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
||||
},
|
||||
{
|
||||
name: "session_denied_without_forwarding",
|
||||
name: "shell_with_forwarding_disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: false,
|
||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
// Connect to the server without requesting PTY or command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
// Execute a command without PTY - this simulates ssh -T with no command
|
||||
// The server should either allow it (port forwarding enabled) or reject it
|
||||
output, err := client.ExecuteCommand(ctx, "")
|
||||
if tt.expectAllowed {
|
||||
// When allowed, the session stays open until cancelled
|
||||
// ExecuteCommand with empty command should return without error
|
||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
||||
assert.NotContains(t, output, "port forwarding is disabled",
|
||||
"Output should not contain port forwarding disabled message")
|
||||
} else if err != nil {
|
||||
// When denied, we expect an error message about port forwarding being disabled
|
||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
||||
"Should get port forwarding disabled message")
|
||||
}
|
||||
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||
// Should always succeed regardless of port forwarding settings
|
||||
_, err = client.ExecuteCommand(ctx, "")
|
||||
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
assert.Equal(t, "-c", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
args = server.getShellCommandArgs("/bin/sh", "")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Len(t, args, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
// ssh -T (or ssh -N) - no PTY, no command
|
||||
s.handleNonInteractiveSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
||||
s.updateSessionType(session, cmdNonInteractive)
|
||||
|
||||
if !s.isPortForwardingEnabled() {
|
||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
<-session.Context().Done()
|
||||
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, state := range s.sessions {
|
||||
if state.session == session {
|
||||
state.sessionType = sessionType
|
||||
return
|
||||
}
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||
} else {
|
||||
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
// handlePtyLogin is not supported on JS/WASM
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
|
||||
@@ -8,19 +8,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartTestServer starts the SSH server and returns the address it's listening on.
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Use port 0 to let the OS assign a free port
|
||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the actual listening address from the server
|
||||
actualAddr := server.Addr()
|
||||
if actualAddr == nil {
|
||||
errChan <- fmt.Errorf("server started but no listener address available")
|
||||
|
||||
@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||
// Returns the command and a cleanup function (no-op on Unix).
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
|
||||
@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
return s.createUserSwitchCommand(logger, session, localUser)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper()
|
||||
dropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
||||
cleanup := func() {
|
||||
if token != 0 {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
logger.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ var (
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: ctx,
|
||||
Context: session.Context(),
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
|
||||
@@ -63,6 +63,8 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleNetworksClick()
|
||||
case <-h.client.mNotifications.ClickedCh:
|
||||
h.handleNotificationsClick()
|
||||
case <-systray.TrayOpenedCh:
|
||||
h.client.updateExitNodes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +341,6 @@ func (s *serviceClient) updateExitNodes() {
|
||||
log.Errorf("get client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
exitNodes, err := s.getExitNodes(conn)
|
||||
if err != nil {
|
||||
log.Errorf("get exit nodes: %v", err)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -31,7 +31,7 @@ require (
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
|
||||
4
go.sum
4
go.sum
@@ -13,8 +13,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
|
||||
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI=
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||
|
||||
@@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
|
||||
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||
connectors, err := p.storage.ListConnectors(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to list connectors: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
|
||||
for _, conn := range connectors {
|
||||
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
|
||||
if conn.ID != "local" || conn.Type != "local" {
|
||||
p.logger.Info("found non-local connector", "id", conn.ID)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
p.logger.Info("no non-local connectors found")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// DisableLocalAuth removes the local (password) connector.
|
||||
// Returns an error if no other connectors are configured.
|
||||
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
|
||||
hasOthers, err := p.HasNonLocalConnectors(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasOthers {
|
||||
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||
}
|
||||
|
||||
// Check if local connector exists
|
||||
_, err = p.storage.GetConnector(ctx, "local")
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
// Already disabled
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check local connector: %w", err)
|
||||
}
|
||||
|
||||
// Delete the local connector
|
||||
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
|
||||
return fmt.Errorf("failed to delete local connector: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Info("local authentication disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
|
||||
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
|
||||
return ensureLocalConnector(ctx, p.storage)
|
||||
}
|
||||
|
||||
// ensureStaticConnectors creates or updates static connectors in storage
|
||||
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
|
||||
for _, conn := range connectors {
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/server"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
@@ -78,9 +78,8 @@ var (
|
||||
}
|
||||
}
|
||||
|
||||
_, valid := dns.IsDomainName(dnsDomain)
|
||||
if !valid || len(dnsDomain) > 192 {
|
||||
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
|
||||
if !nbdomain.IsValidDomainNoWildcard(dnsDomain) {
|
||||
return fmt.Errorf("invalid dns-domain: %s", dnsDomain)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
}
|
||||
|
||||
for accountID, peerIDs := range peerIDsPerAccount {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
|
||||
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
|
||||
log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
|
||||
log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)",
|
||||
peerID, peer.Status.Connected,
|
||||
peer.Status.LastSeen.Format(time.RFC3339),
|
||||
time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339),
|
||||
peer.Ephemeral)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if m.integratedPeerValidator != nil {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
@@ -63,7 +63,7 @@ func (r *Record) Validate() error {
|
||||
return errors.New("record name is required")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(r.Name) {
|
||||
if !domain.IsValidDomain(r.Name) {
|
||||
return errors.New("invalid record name format")
|
||||
}
|
||||
|
||||
@@ -81,8 +81,8 @@ func (r *Record) Validate() error {
|
||||
return err
|
||||
}
|
||||
case RecordTypeCNAME:
|
||||
if !util.IsValidDomain(r.Content) {
|
||||
return errors.New("invalid CNAME record format")
|
||||
if !domain.IsValidDomainNoWildcard(r.Content) {
|
||||
return errors.New("invalid CNAME target format")
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid record type, must be A, AAAA, or CNAME")
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ func (z *Zone) Validate() error {
|
||||
return errors.New("zone name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if !util.IsValidDomain(z.Domain) {
|
||||
if !domain.IsValidDomainNoWildcard(z.Domain) {
|
||||
return errors.New("invalid zone domain format")
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,14 @@ func (s *BaseServer) UsersManager() users.Manager {
|
||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
return Create(s, func() settings.Manager {
|
||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
||||
|
||||
idpConfig := settings.IdpConfig{}
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpConfig.EmbeddedIdpEnabled = true
|
||||
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
|
||||
}
|
||||
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,13 +17,14 @@ import (
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -76,8 +77,9 @@ type Server struct {
|
||||
|
||||
oAuthConfigProvider idp.OAuthConfigProvider
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
syncSem atomic.Int32
|
||||
syncLimEnabled bool
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -107,6 +109,7 @@ func NewServer(
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
syncLimEnabled := true
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
@@ -114,6 +117,9 @@ func NewServer(
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
if syncLim < 0 {
|
||||
syncLimEnabled = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,7 +139,8 @@ func NewServer(
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
syncLim: syncLim,
|
||||
syncLim: syncLim,
|
||||
syncLimEnabled: syncLimEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -211,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
@@ -232,6 +239,9 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -301,6 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -308,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -479,6 +490,10 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
|
||||
@@ -48,6 +48,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -231,7 +232,7 @@ func BuildManager(
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
|
||||
if am.singleAccountMode {
|
||||
if !isDomainValid(singleAccountModeDomain) {
|
||||
if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
||||
}
|
||||
am.singleAccountModeDomain = singleAccountModeDomain
|
||||
@@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
|
||||
if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) {
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
@@ -794,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
|
||||
// Returns true only when using embedded IDP with local auth disabled in config.
|
||||
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
|
||||
if isNil(i) {
|
||||
return false
|
||||
}
|
||||
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return embeddedIdp.IsLocalAuthDisabled()
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
@@ -1691,10 +1705,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return nil
|
||||
}
|
||||
|
||||
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||
// isDomainValid validates public/IDP domains using stricter rules than internal DNS domains.
|
||||
// Requires at least 2-char alphabetic TLD and no single-label domains.
|
||||
var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||
|
||||
func isDomainValid(domain string) bool {
|
||||
return invalidDomainRegexp.MatchString(domain)
|
||||
return publicDomainRegexp.MatchString(domain)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
|
||||
|
||||
@@ -30,6 +30,12 @@ type Manager interface {
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInvite(ctx context.Context, token, password string) error
|
||||
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
|
||||
|
||||
@@ -199,6 +199,11 @@ const (
|
||||
|
||||
UserPasswordChanged Activity = 103
|
||||
|
||||
UserInviteLinkCreated Activity = 104
|
||||
UserInviteLinkAccepted Activity = 105
|
||||
UserInviteLinkRegenerated Activity = 106
|
||||
UserInviteLinkDeleted Activity = 107
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{
|
||||
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
|
||||
|
||||
UserPasswordChanged: {"User password changed", "user.password.change"},
|
||||
|
||||
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
|
||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
@@ -68,6 +69,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if err := bypass.AddBypassPath("/api/setup"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
// Public invite endpoints (tokens start with nbi_)
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
@@ -122,16 +130,18 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
@@ -145,6 +155,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
|
||||
@@ -36,24 +36,22 @@ const (
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
embeddedIdpEnabled bool
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
embeddedIdpEnabled: embeddedIdpEnabled,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
@@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
@@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
EmbeddedIdpEnabled: &embeddedIdpEnabled,
|
||||
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: &settings.LocalAuthDisabled,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
AnyTimes()
|
||||
|
||||
return &handler{
|
||||
embeddedIdpEnabled: false,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
@@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
// This endpoint is unauthenticated.
|
||||
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -37,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired)
|
||||
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
|
||||
SetupRequired: setupRequired,
|
||||
})
|
||||
@@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
Email: userData.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.InstanceVersionInfo{
|
||||
ManagementCurrentVersion: versionInfo.CurrentVersion,
|
||||
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
|
||||
}
|
||||
|
||||
if versionInfo.DashboardVersion != "" {
|
||||
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
|
||||
}
|
||||
|
||||
if versionInfo.ManagementVersion != "" {
|
||||
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
@@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
if m.getVersionInfoFn != nil {
|
||||
return m.getVersionInfoFn(ctx)
|
||||
}
|
||||
return &nbinstance.VersionInfo{
|
||||
CurrentVersion: "0.34.0",
|
||||
DashboardVersion: "2.0.0",
|
||||
ManagementVersion: "0.35.0",
|
||||
ManagementUpdateAvailable: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
@@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.InstanceVersionInfo
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
|
||||
assert.NotNil(t, response.DashboardAvailableVersion)
|
||||
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
|
||||
assert.NotNil(t, response.ManagementAvailableVersion)
|
||||
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
|
||||
assert.True(t, response.ManagementUpdateAvailable)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Error(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
return nil, errors.New("failed to fetch versions")
|
||||
},
|
||||
}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -26,11 +27,12 @@ import (
|
||||
// Handler is a handler that returns peers of the account
|
||||
type Handler struct {
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
@@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
|
||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
|
||||
return &Handler{
|
||||
accountManager: accountManager,
|
||||
networkMapController: networkMapController,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,13 +13,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/mock/gomock"
|
||||
ugomock "go.uber.org/mock/gomock"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
ctrl := ugomock.NewController(t)
|
||||
|
||||
networkMapController := network_map.NewMockController(ctrl)
|
||||
networkMapController.EXPECT().
|
||||
@@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
Return("domain").
|
||||
AnyTimes()
|
||||
|
||||
ctrl2 := gomock.NewController(t)
|
||||
permissionsManager := permissions.NewMockManager(ctrl2)
|
||||
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
return &Handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
@@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
},
|
||||
networkMapController: networkMapController,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
263
management/server/http/handlers/users/invites_handler.go
Normal file
263
management/server/http/handlers/users/invites_handler.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
|
||||
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10, // 10 attempts per minute per IP
|
||||
Burst: 5, // Allow burst of 5 requests
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
LimiterTTL: 30 * time.Minute,
|
||||
})
|
||||
|
||||
// toUserInviteResponse converts a UserInvite to an API response.
|
||||
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
|
||||
autoGroups := invite.UserInfo.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
var inviteLink *string
|
||||
if invite.InviteToken != "" {
|
||||
inviteLink = &invite.InviteToken
|
||||
}
|
||||
return api.UserInvite{
|
||||
Id: invite.UserInfo.ID,
|
||||
Email: invite.UserInfo.Email,
|
||||
Name: invite.UserInfo.Name,
|
||||
Role: invite.UserInfo.Role,
|
||||
AutoGroups: autoGroups,
|
||||
ExpiresAt: invite.InviteExpiresAt.UTC(),
|
||||
CreatedAt: invite.InviteCreatedAt.UTC(),
|
||||
Expired: time.Now().After(invite.InviteExpiresAt),
|
||||
InviteToken: inviteLink,
|
||||
}
|
||||
}
|
||||
|
||||
// invitesHandler handles user invite operations
|
||||
type invitesHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Create a subrouter for public invite endpoints with rate limiting middleware
|
||||
publicRouter := router.PathPrefix("/users/invites").Subrouter()
|
||||
publicRouter.Use(publicInviteRateLimiter.Middleware)
|
||||
|
||||
// Public endpoints (no auth required, protected by token and rate limited)
|
||||
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
|
||||
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.UserInvite, 0, len(invites))
|
||||
for _, invite := range invites {
|
||||
resp = append(resp, toUserInviteResponse(invite))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
invite := &types.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
result.InviteCreatedAt = time.Now().UTC()
|
||||
resp := toUserInviteResponse(result)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// getInviteInfo handles GET /api/users/invites/{token}
|
||||
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := info.ExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
ExpiresAt: expiresAt,
|
||||
Valid: info.Valid,
|
||||
InvitedBy: info.InvitedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// acceptInvite handles POST /api/users/invites/{token}/accept
|
||||
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteAcceptRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteRegenerateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Allow empty body (io.EOF) - expiresIn is optional
|
||||
if !errors.Is(err, io.EOF) {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := result.InviteExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
|
||||
InviteToken: result.InviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
659
management/server/http/handlers/users/invites_handler_test.go
Normal file
659
management/server/http/handlers/users/invites_handler_test.go
Normal file
@@ -0,0 +1,659 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testInviteID = "test-invite-id"
|
||||
testInviteToken = "nbi_testtoken123456789012345678"
|
||||
testEmail = "invite@example.com"
|
||||
testName = "Test User"
|
||||
)
|
||||
|
||||
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
|
||||
return &invitesHandler{
|
||||
accountManager: am,
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInvites(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
testInvites := []*types.UserInvite{
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-1",
|
||||
Email: "user1@example.com",
|
||||
Name: "User One",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
},
|
||||
InviteExpiresAt: now.Add(24 * time.Hour),
|
||||
InviteCreatedAt: now,
|
||||
},
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-2",
|
||||
Email: "user2@example.com",
|
||||
Name: "User Two",
|
||||
Role: "admin",
|
||||
AutoGroups: nil,
|
||||
},
|
||||
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
|
||||
InviteCreatedAt: now.Add(-48 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return testInvites, nil
|
||||
},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return []*types.UserInvite{}, nil
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
ListUserInvitesFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp []api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, tc.expectedCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful create with custom expiration",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 3600, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: []string{},
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user already exists",
|
||||
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite already exists",
|
||||
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "local auth disabled",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
CreateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testInviteID, resp.Id)
|
||||
assert.NotNil(t, resp.InviteToken)
|
||||
assert.NotEmpty(t, *resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInviteInfo(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
}{
|
||||
{
|
||||
name: "successful get valid invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
Valid: true,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful get expired invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(-24 * time.Hour),
|
||||
Valid: false,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
GetUserInviteInfoFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.getInviteInfo(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteInfo
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testEmail, resp.Email)
|
||||
assert.Equal(t, testName, resp.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token, password string) error
|
||||
}{
|
||||
{
|
||||
name: "successful accept",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite expired",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "invite has expired")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "local auth disabled",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
token: testInviteToken,
|
||||
requestBody: `{invalid}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"Short1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing digit",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoDigitPass!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing uppercase",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"nouppercase1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing special character",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoSpecial123"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
AcceptUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.acceptInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteAcceptResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Success)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegenerateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful regenerate with empty body",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 0, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful regenerate with custom expiration",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{"expires_in":7200}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 7200, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON should return error",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
RegenerateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
var body io.Reader
|
||||
if tc.requestBody != "" {
|
||||
body = bytes.NewBufferString(tc.requestBody)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteRegenerateResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}{
|
||||
{
|
||||
name: "successful delete",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
DeleteUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that rate limits requests by client IP.
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
158
management/server/http/middleware/rate_limiter_test.go
Normal file
158
management/server/http/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAPIRateLimiter_Allow(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// First two requests should be allowed (burst)
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
|
||||
// Third request should be denied (exceeded burst)
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Different key should be allowed
|
||||
assert.True(t, rl.Allow("different-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Create a simple handler that returns 200 OK
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Wrap with rate limiter middleware
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// First two requests should pass (burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// Request from first IP
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
assert.Equal(t, http.StatusOK, rr1.Code)
|
||||
|
||||
// Second request from first IP should be rate limited
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.1:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
|
||||
|
||||
// Request from different IP should be allowed
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.2:12345"
|
||||
rr3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr3, req3)
|
||||
assert.Equal(t, http.StatusOK, rr3.Code)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "remote addr with port",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "remote addr without port",
|
||||
remoteAddr: "192.168.1.1",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port",
|
||||
remoteAddr: "[::1]:12345",
|
||||
expected: "::1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 without port",
|
||||
remoteAddr: "::1",
|
||||
expected: "::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
assert.Equal(t, tc.expected, getClientIP(req))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Use up the burst
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Reset the limiter
|
||||
rl.Reset("test-key")
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
@@ -20,7 +20,7 @@ const (
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email"
|
||||
defaultScopes = "openid profile email groups"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
@@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct {
|
||||
Owner *OwnerConfig
|
||||
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
|
||||
SignKeyRefreshEnabled bool
|
||||
// LocalAuthDisabled disables the local (email/password) authentication connector.
|
||||
// When true, users cannot authenticate via email/password, only via external identity providers.
|
||||
// Existing local users are preserved and will be able to login again if re-enabled.
|
||||
// Cannot be enabled if no external identity provider connectors are configured.
|
||||
LocalAuthDisabled bool
|
||||
}
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
@@ -105,6 +110,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
Issuer: "NetBird",
|
||||
Theme: "light",
|
||||
},
|
||||
// Always enable password DB initially - we disable the local connector after startup if needed.
|
||||
// This ensures Dex has at least one connector during initialization.
|
||||
EnablePasswordDB: true,
|
||||
StaticClients: []storage.Client{
|
||||
{
|
||||
@@ -192,11 +199,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config)
|
||||
|
||||
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
|
||||
}
|
||||
|
||||
// If local auth is disabled, validate that other connectors exist
|
||||
if config.LocalAuthDisabled {
|
||||
hasOthers, err := provider.HasNonLocalConnectors(ctx)
|
||||
if err != nil {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("failed to check connectors: %w", err)
|
||||
}
|
||||
if !hasOthers {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured")
|
||||
}
|
||||
// Ensure local connector is removed (it might exist from a previous run)
|
||||
if err := provider.DisableLocalAuth(ctx); err != nil {
|
||||
_ = provider.Stop(ctx)
|
||||
return nil, fmt.Errorf("failed to disable local auth: %w", err)
|
||||
}
|
||||
log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
|
||||
|
||||
return &EmbeddedIdPManager{
|
||||
@@ -281,6 +309,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users))
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range users {
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
|
||||
@@ -290,11 +320,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
|
||||
})
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID]))
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in the embedded IdP.
|
||||
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if m.config.LocalAuthDisabled {
|
||||
return nil, fmt.Errorf("local user creation is disabled")
|
||||
}
|
||||
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
@@ -364,6 +400,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) (
|
||||
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
|
||||
// This is useful for instance setup where the user provides their own password.
|
||||
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
|
||||
if m.config.LocalAuthDisabled {
|
||||
return nil, fmt.Errorf("local user creation is disabled")
|
||||
}
|
||||
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
@@ -553,3 +593,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
|
||||
return defaultUserIDClaim
|
||||
}
|
||||
|
||||
// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration.
|
||||
func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
|
||||
return m.config.LocalAuthDisabled
|
||||
}
|
||||
|
||||
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
|
||||
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
|
||||
return m.provider.HasNonLocalConnectors(ctx)
|
||||
}
|
||||
|
||||
@@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no other identity providers configured")
|
||||
})
|
||||
|
||||
t.Run("local auth enabled by default", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Verify local auth is enabled by default
|
||||
assert.False(t, manager.IsLocalAuthDisabled())
|
||||
})
|
||||
|
||||
t.Run("start with local auth disabled when connector exists", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager with local auth enabled and add a connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a user
|
||||
userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
userID := userData.ID
|
||||
|
||||
// Add an external connector (Google doesn't require OIDC discovery)
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stop the first manager
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now create a new manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Verify local auth is disabled via config
|
||||
assert.True(t, manager2.IsLocalAuthDisabled())
|
||||
|
||||
// Verify the user still exists in storage (just can't login via local)
|
||||
lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preserved@example.com", lookedUp.Email)
|
||||
})
|
||||
|
||||
t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager and add an external connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Try to create a user - should fail
|
||||
_, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||
})
|
||||
|
||||
t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First, create a manager and add an external connector
|
||||
config1 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
|
||||
ID: "google-test",
|
||||
Name: "Google Test",
|
||||
Type: "google",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager1.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create manager with local auth disabled
|
||||
config2 := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAuthDisabled: true,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: dbFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager2.Stop(ctx) }()
|
||||
|
||||
// Try to create a user with password - should fail
|
||||
_, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "local user creation is disabled")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,18 +2,54 @@ package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// Version endpoints
|
||||
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
|
||||
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
|
||||
|
||||
// Cache TTL for version information
|
||||
versionCacheTTL = 60 * time.Minute
|
||||
|
||||
// HTTP client timeout
|
||||
httpTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// VersionInfo contains version information for NetBird components
|
||||
type VersionInfo struct {
|
||||
// CurrentVersion is the running management server version
|
||||
CurrentVersion string
|
||||
// DashboardVersion is the latest available dashboard version from GitHub
|
||||
DashboardVersion string
|
||||
// ManagementVersion is the latest available management version from GitHub
|
||||
ManagementVersion string
|
||||
// ManagementUpdateAvailable indicates if a newer management version is available
|
||||
ManagementUpdateAvailable bool
|
||||
}
|
||||
|
||||
// githubRelease represents a GitHub release response
|
||||
type githubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
}
|
||||
|
||||
// Manager handles instance-level operations like initial setup.
|
||||
type Manager interface {
|
||||
// IsSetupRequired checks if instance setup is required.
|
||||
@@ -23,6 +59,9 @@ type Manager interface {
|
||||
// CreateOwnerUser creates the initial owner user in the embedded IDP.
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of Manager.
|
||||
@@ -32,6 +71,12 @@ type DefaultManager struct {
|
||||
|
||||
setupRequired bool
|
||||
setupMu sync.RWMutex
|
||||
|
||||
// Version caching
|
||||
httpClient *http.Client
|
||||
versionMu sync.RWMutex
|
||||
cachedVersions *VersionInfo
|
||||
lastVersionFetch time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new instance manager.
|
||||
@@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
||||
store: store,
|
||||
embeddedIdpManager: embeddedIdp,
|
||||
setupRequired: false,
|
||||
httpClient: &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
if embeddedIdp != nil {
|
||||
@@ -56,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
|
||||
// Check if there are any accounts in the NetBird store
|
||||
numAccounts, err := m.store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasAccounts := numAccounts > 0
|
||||
|
||||
// Check if there are any users in the embedded IdP (Dex)
|
||||
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasLocalUsers := len(users) > 0
|
||||
|
||||
m.setupMu.Lock()
|
||||
m.setupRequired = len(users) == 0
|
||||
m.setupRequired = !(hasAccounts || hasLocalUsers)
|
||||
m.setupMu.Unlock()
|
||||
|
||||
return nil
|
||||
@@ -134,3 +191,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.RLock()
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.RUnlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.RUnlock()
|
||||
|
||||
return m.fetchVersionInfo(ctx)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.Unlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.Unlock()
|
||||
|
||||
info := &VersionInfo{
|
||||
CurrentVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
// Fetch management version from pkgs.netbird.io (plain text)
|
||||
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
|
||||
} else {
|
||||
info.ManagementVersion = mgmtVersion
|
||||
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
|
||||
}
|
||||
|
||||
// Fetch dashboard version from GitHub
|
||||
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
|
||||
} else {
|
||||
info.DashboardVersion = dashVersion
|
||||
}
|
||||
|
||||
// Update cache
|
||||
m.versionMu.Lock()
|
||||
m.cachedVersions = info
|
||||
m.lastVersionFetch = time.Now()
|
||||
m.versionMu.Unlock()
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// isNewerVersion returns true if latestVersion is greater than currentVersion
|
||||
func isNewerVersion(currentVersion, latestVersion string) bool {
|
||||
current, err := goversion.NewVersion(currentVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
latest, err := goversion.NewVersion(latestVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return latest.GreaterThan(current)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release githubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
// Remove 'v' prefix if present
|
||||
tag := release.TagName
|
||||
if len(tag) > 0 && tag[0] == 'v' {
|
||||
tag = tag[1:]
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
285
management/server/instance/version_test.go
Normal file
285
management/server/instance/version_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockRoundTripper implements http.RoundTripper for testing
|
||||
type mockRoundTripper struct {
|
||||
callCount atomic.Int32
|
||||
managementVersion string
|
||||
dashboardVersion string
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
var body string
|
||||
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
|
||||
// Plain text response for management version
|
||||
body = m.managementVersion
|
||||
} else if strings.Contains(req.URL.String(), "github.com") {
|
||||
// JSON response for dashboard version
|
||||
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
|
||||
body = string(jsonResp)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
info, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// CurrentVersion should always be set
|
||||
assert.NotEmpty(t, info.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info.ManagementVersion)
|
||||
assert.Equal(t, "2.10.0", info.DashboardVersion)
|
||||
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First call
|
||||
info1, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, info1.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info1.ManagementVersion)
|
||||
|
||||
initialCallCount := mockTransport.callCount.Load()
|
||||
|
||||
// Second call should use cache (no additional HTTP calls)
|
||||
info2, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
|
||||
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
|
||||
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
|
||||
|
||||
// Verify no additional HTTP calls were made (cache was used)
|
||||
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tagName string
|
||||
expected string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "tag with v prefix",
|
||||
tagName: "v1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag without v prefix",
|
||||
tagName: "1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag with prerelease",
|
||||
tagName: "v2.0.0-beta.1",
|
||||
expected: "2.0.0-beta.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
|
||||
if tc.shouldError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: `{"message": "Not Found"}`,
|
||||
},
|
||||
{
|
||||
name: "rate limited",
|
||||
statusCode: http.StatusForbidden,
|
||||
body: `{"message": "API rate limit exceeded"}`,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: `{"message": "Internal Server Error"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.statusCode)
|
||||
_, _ = w.Write([]byte(tc.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{invalid json}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := m.fetchGitHubRelease(ctx, server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsNewerVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentVersion string
|
||||
latestVersion string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "latest is newer - minor version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.65.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - patch version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - major version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "1.0.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "versions are equal",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - minor version",
|
||||
currentVersion: "0.65.0",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - patch version",
|
||||
currentVersion: "0.64.2",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "development version",
|
||||
currentVersion: "development",
|
||||
latestVersion: "0.65.0",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid latest version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "invalid",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,6 +139,12 @@ type MockAccountManager struct {
|
||||
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
|
||||
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
@@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.CreateUserInviteFunc != nil {
|
||||
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
|
||||
if am.AcceptUserInviteFunc != nil {
|
||||
return am.AcceptUserInviteFunc(ctx, token, password)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.RegenerateUserInviteFunc != nil {
|
||||
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
if am.GetUserInviteInfoFunc != nil {
|
||||
return am.GetUserInviteInfoFunc(ctx, token)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
if am.ListUserInvitesFunc != nil {
|
||||
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
if am.DeleteUserInviteFunc != nil {
|
||||
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||
|
||||
@@ -3,10 +3,10 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/xid"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -15,11 +15,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
||||
|
||||
var errInvalidDomainName = errors.New("invalid domain name")
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
@@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var domainMatcher = regexp.MustCompile(domainPattern)
|
||||
|
||||
func validateDomain(domain string) error {
|
||||
if !domainMatcher.MatchString(domain) {
|
||||
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
|
||||
// validateDomain validates a nameserver match domain.
|
||||
// Converts unicode to punycode. Wildcards are not allowed for nameservers.
|
||||
func validateDomain(d string) error {
|
||||
if strings.HasPrefix(d, "*.") {
|
||||
return errors.New("wildcards not allowed")
|
||||
}
|
||||
|
||||
_, valid := dns.IsDomainName(domain)
|
||||
if !valid {
|
||||
return errInvalidDomainName
|
||||
// Nameservers allow trailing dot (FQDN format)
|
||||
toValidate := strings.TrimSuffix(d, ".")
|
||||
|
||||
if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil {
|
||||
return fmt.Errorf("%w: %w", errInvalidDomainName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// TestValidateDomain tests nameserver-specific domain validation.
|
||||
// Core domain validation is tested in shared/management/domain/validate_test.go.
|
||||
// This test only covers nameserver-specific behavior: wildcard rejection and unicode support.
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
errFunc require.ErrorAssertionFunc
|
||||
}{
|
||||
// Nameserver-specific: wildcards not allowed
|
||||
{
|
||||
name: "Valid domain name with multiple labels",
|
||||
domain: "123.example.com",
|
||||
name: "Wildcard prefix rejected",
|
||||
domain: "*.example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Wildcard in middle rejected",
|
||||
domain: "a.*.example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
// Nameserver-specific: unicode converted to punycode
|
||||
{
|
||||
name: "Unicode domain converted to punycode",
|
||||
domain: "münchen.de",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Valid domain name with hyphen",
|
||||
domain: "test-example.com",
|
||||
name: "Unicode domain all labels",
|
||||
domain: "中国.中国",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
// Basic validation still works (delegates to shared validation)
|
||||
{
|
||||
name: "Valid multi-label domain",
|
||||
domain: "example.com",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Valid domain name with only one label",
|
||||
domain: "example",
|
||||
name: "Valid single label",
|
||||
domain: "internal",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Valid domain name with trailing dot",
|
||||
domain: "example.",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard domain name",
|
||||
domain: "*.example",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with leading dot",
|
||||
domain: ".com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with dot only",
|
||||
domain: ".",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with double hyphen",
|
||||
domain: "test--example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with a label exceeding 63 characters",
|
||||
domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name starting with a hyphen",
|
||||
name: "Invalid leading hyphen",
|
||||
domain: "-example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name ending with a hyphen",
|
||||
domain: "example.com-",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain with unicode",
|
||||
domain: "example?,.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain with space before top-level domain",
|
||||
domain: "space .example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain with trailing space",
|
||||
domain: "example.com ",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
|
||||
@@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
|
||||
NetworkID: "testNetworkId",
|
||||
Name: "testResourceId",
|
||||
Description: "description",
|
||||
Address: "invalid-address",
|
||||
Address: "-invalid",
|
||||
}
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
@@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
|
||||
resource := &types.NetworkResource{
|
||||
AccountID: "testAccountId",
|
||||
NetworkID: "testNetworkId",
|
||||
Name: "testResourceId",
|
||||
Name: "used-name",
|
||||
Description: "description",
|
||||
Address: "invalid-address",
|
||||
Address: "example.com",
|
||||
}
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
@@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
|
||||
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
|
||||
if domainRegex.MatchString(address) {
|
||||
if _, err := nbDomain.ValidateDomains([]string{address}); err == nil {
|
||||
return Domain, address, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) {
|
||||
{"example.com", Domain, false, "example.com", netip.Prefix{}},
|
||||
{"*.example.com", Domain, false, "*.example.com", netip.Prefix{}},
|
||||
{"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}},
|
||||
{"example.x", Domain, false, "example.x", netip.Prefix{}},
|
||||
{"internal", Domain, false, "internal", netip.Prefix{}},
|
||||
// Invalid inputs
|
||||
{"invalid", "", true, "", netip.Prefix{}},
|
||||
{"1.1.1.1/abc", "", true, "", netip.Prefix{}},
|
||||
{"1234", "", true, "", netip.Prefix{}},
|
||||
{"-invalid.com", "", true, "", netip.Prefix{}},
|
||||
{"", "", true, "", netip.Prefix{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if temporary {
|
||||
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
|
||||
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
@@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if ephemeral {
|
||||
// we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected
|
||||
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user