mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Compare commits
2 Commits
v0.64.2
...
log/conn-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97ad3307dd | ||
|
|
f6cc27d675 |
@@ -3,7 +3,15 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
@@ -108,17 +129,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -137,41 +160,49 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
go urlOpener.OnLoginSuccess()
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
@@ -181,10 +212,22 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -276,19 +277,18 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin := false
|
||||
|
||||
err, isAuthError := authClient.Login(ctx, "", "")
|
||||
if isAuthError {
|
||||
needsLogin = true
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("login check failed: %v", err)
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -300,9 +300,23 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -330,7 +344,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -177,13 +176,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth client: %w", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||
}
|
||||
|
||||
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("create chain: %w", err)
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -168,10 +168,6 @@ type Manager interface {
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -49,10 +48,8 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||
return fmt.Errorf("notrack chains not initialized")
|
||||
}
|
||||
|
||||
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||
loopback := []byte{127, 0, 0, 1}
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush notrack rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawPrerouting,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush chain creation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) refreshNoTrackChains() error {
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, c := range chains {
|
||||
if c.Table.Name != tableName {
|
||||
continue
|
||||
}
|
||||
switch c.Name {
|
||||
case chainNameRawOutput:
|
||||
m.notrackOutputChain = c
|
||||
case chainNameRawPrerouting:
|
||||
m.notrackPreroutingChain = c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -570,14 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
|
||||
@@ -50,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -81,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||
func (w *WGIface) GetProxyPort() uint16 {
|
||||
return w.wgProxyFactory.GetProxyPort()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -114,21 +114,34 @@ func (p *ProxyBind) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start package redirection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = ep
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert endpoint address: %v", err)
|
||||
} else {
|
||||
p.wgCurrentUsed = ep
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, errors.New("nil address")
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
@@ -212,16 +225,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid address")
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
@@ -27,19 +27,12 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||
localHostNetIPv6 = net.ParseIP("::1")
|
||||
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
proxyPort int
|
||||
mtu uint16
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
@@ -47,8 +40,7 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConnIPv4 net.PacketConn
|
||||
rawConnIPv6 net.PacketConn
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
@@ -70,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
proxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.proxyPort = proxyPort
|
||||
|
||||
// Prepare IPv4 raw socket (required)
|
||||
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare IPv6 raw socket (optional)
|
||||
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||
}
|
||||
if p.rawConnIPv6 != nil {
|
||||
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: proxyPort,
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
@@ -118,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -159,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if p.rawConnIPv4 != nil {
|
||||
if err := p.rawConnIPv4.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.rawConnIPv6 != nil {
|
||||
if err := p.rawConnIPv6.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy listening port.
|
||||
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||
return uint16(p.proxyPort)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -255,60 +218,31 @@ generatePort:
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
var dstIP net.IP
|
||||
var rawConn net.PacketConn
|
||||
|
||||
if endpointAddr.IP.To4() != nil {
|
||||
// IPv4 path
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPv4,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
dstIP = localHostNetIPv4
|
||||
rawConn = p.rawConnIPv4
|
||||
} else {
|
||||
// IPv6 path
|
||||
if p.rawConnIPv6 == nil {
|
||||
return fmt.Errorf("IPv6 raw socket not available")
|
||||
}
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPv6,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
dstIP = localHostNetIPv6
|
||||
rawConn = p.rawConnIPv6
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
|
||||
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -41,7 +41,7 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
@@ -91,14 +91,12 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||
return
|
||||
}
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
if endpoint != nil && endpoint.IP != nil {
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
@@ -54,14 +54,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||
if w.ebpfProxy == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ebpfProxy.GetProxyPort()
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
|
||||
@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||
func (w *USPFactory) GetProxyPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,87 +8,43 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||
}
|
||||
|
||||
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||
if isIPv4 {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||
}
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
@@ -1,353 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
return addr1.String() == addr2.String()
|
||||
}
|
||||
|
||||
// Compare IP and Port, ignoring zone
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||
wgPort := 51853
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||
// It verifies that:
|
||||
// 1. Initial traffic from relay connection works
|
||||
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||
// 3. Multiple packets are correctly redirected with the new source address
|
||||
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener4.Close()
|
||||
|
||||
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener6.Close()
|
||||
|
||||
// Determine which listener to use based on the NetBird address IP version
|
||||
// (this is where initial traffic will come from before RedirectAs is called)
|
||||
var wgListener *net.UDPConn
|
||||
if p2pEndpoint.IP.To4() == nil {
|
||||
wgListener = wgListener6
|
||||
} else {
|
||||
wgListener = wgListener4
|
||||
}
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0, // Random port
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
// Add TURN connection to proxy
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the proxy
|
||||
proxy.Work()
|
||||
|
||||
// Phase 1: Test initial relay traffic
|
||||
msgFromRelay := []byte("hello from relay")
|
||||
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write to relay server: %v", err)
|
||||
}
|
||||
|
||||
// Set read deadline to avoid hanging
|
||||
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _, err := wgListener4.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msgFromRelay) {
|
||||
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||
}
|
||||
|
||||
// Phase 2: Redirect to p2p endpoint
|
||||
proxy.RedirectAs(p2pEndpoint)
|
||||
|
||||
// Give the proxy a moment to process the redirect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Phase 3: Test redirected traffic
|
||||
redirectedMessages := [][]byte{
|
||||
[]byte("redirected message 1"),
|
||||
[]byte("redirected message 2"),
|
||||
[]byte("redirected message 3"),
|
||||
}
|
||||
|
||||
for i, msg := range redirectedMessages {
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
// Verify message content
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||
}
|
||||
|
||||
// Verify source address matches p2p endpoint (this is the key test)
|
||||
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||
t.Errorf("message %d: expected source address %s, got %s",
|
||||
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listener
|
||||
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener.Close()
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
// Test switching between multiple endpoints - using addresses in local subnet
|
||||
endpoints := []*net.UDPAddr{
|
||||
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||
}
|
||||
|
||||
for i, endpoint := range endpoints {
|
||||
proxy.RedirectAs(endpoint)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg := []byte("test message")
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||
}
|
||||
|
||||
if !compareUDPAddr(srcAddr, endpoint) {
|
||||
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||
i, endpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
|
||||
@@ -19,56 +19,37 @@ var (
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddrV4 = &net.IPAddr{
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
localHostNetIPAddrV6 = &net.IPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr *net.IPAddr
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
// Create only the raw socket for the address family we need
|
||||
var rawSocket net.PacketConn
|
||||
var err error
|
||||
var localHostAddr *net.IPAddr
|
||||
|
||||
if srcAddr.IP.To4() != nil {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
localHostAddr = localHostNetIPAddrV4
|
||||
} else {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
localHostAddr = localHostNetIPAddrV6
|
||||
}
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
|
||||
// Check if source IP is IPv4 or IPv6
|
||||
if srcAddr.IP.To4() != nil {
|
||||
// IPv4
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPAddrV4.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
} else {
|
||||
// IPv6
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPAddrV6.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,499 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -25,56 +26,12 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -247,22 +199,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -86,33 +87,19 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return pkceFlowInfo, nil
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -127,9 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return deviceFlowInfo, nil
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -34,67 +35,17 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -173,21 +124,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -198,7 +138,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -209,8 +149,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -49,7 +50,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := PKCEAuthProviderConfig{
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -93,7 +95,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := PKCEAuthProviderConfig{
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
136
client/internal/device_auth.go
Normal file
136
client/internal/device_auth.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -505,10 +505,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
// Set up notrack rules immediately after proxy is listening to prevent
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -621,23 +617,6 @@ func (e *Engine) initFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
|
||||
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
|
||||
func (e *Engine) setupWGProxyNoTrack() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPort := e.wgInterface.GetProxyPort()
|
||||
if proxyPort == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
|
||||
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) blockLanAccess() {
|
||||
if e.config.BlockInbound {
|
||||
// no need to set up extra deny rules if inbound is already blocked in general
|
||||
@@ -1665,7 +1644,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
|
||||
@@ -107,7 +107,6 @@ type MockWGIface struct {
|
||||
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetProxyPortFunc func() uint16
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
@@ -204,13 +203,6 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxyPort() uint16 {
|
||||
if m.GetProxyPortFunc != nil {
|
||||
return m.GetProxyPortFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ type wgIfaceBase interface {
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemoveEndpointAddress(key string) error
|
||||
RemovePeer(peerKey string) error
|
||||
|
||||
201
client/internal/login.go
Normal file
201
client/internal/login.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
||||
mgmURL := config.ManagementURL
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
cStatus, ok := status.FromError(err)
|
||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.Code() == codes.PermissionDenied {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
138
client/internal/pkce_auth.go
Normal file
138
client/internal/pkce_auth.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig PKCEAuthProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
authFlow := PKCEAuthorizationFlow{
|
||||
ProviderConfig: PKCEAuthProviderConfig{
|
||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
ClientCertPair: clientCert,
|
||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||
},
|
||||
}
|
||||
|
||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return PKCEAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return authFlow, nil
|
||||
}
|
||||
|
||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -263,14 +263,7 @@ func (c *Client) IsLoginRequired() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||
return true // Assume login is required if we can't create auth client
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||
// If the check fails, assume login is required to be safe
|
||||
@@ -321,19 +314,16 @@ func (c *Client) LoginForMobile() string {
|
||||
|
||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||
go func() {
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||
return
|
||||
}
|
||||
jwtToken := tokenInfo.GetTokenToUse()
|
||||
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||
return
|
||||
}
|
||||
defer authClient.Close()
|
||||
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,8 +7,13 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -85,21 +90,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
// which are blocked by the tvOS sandbox in App Group containers
|
||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -123,17 +141,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||
@@ -144,16 +164,15 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
// LoginSync performs a synchronous login check without UI interaction
|
||||
// Used for background VPN connection where user should already be authenticated
|
||||
func (a *Auth) LoginSync() error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -161,12 +180,15 @@ func (a *Auth) LoginSync() error {
|
||||
return fmt.Errorf("not authenticated")
|
||||
}
|
||||
|
||||
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -203,6 +225,8 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||
var needsLogin bool
|
||||
|
||||
// Create context with device name if provided
|
||||
ctx := a.ctx
|
||||
if deviceName != "" {
|
||||
@@ -210,33 +234,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
}
|
||||
|
||||
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
err := a.withBackOff(ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
err = a.withBackOff(ctx, func() error {
|
||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// PermissionDenied means registration is required or peer is blocked
|
||||
return fmt.Errorf("authentication error: %v", err)
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -261,10 +285,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
||||
|
||||
const authInfoRequestTimeout = 30 * time.Second
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||
@@ -289,6 +313,15 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
// GetConfigJSON returns the current config as a JSON string.
|
||||
// This can be used by the caller to persist the config via alternative storage
|
||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
|
||||
@@ -253,17 +253,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create auth client: %v", err)
|
||||
return internal.StatusLoginFailed, err
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
var status internal.StatusType
|
||||
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
if isAuthError {
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
log.Warnf("failed login: %v", err)
|
||||
status = internal.StatusNeedsLogin
|
||||
} else {
|
||||
@@ -588,7 +581,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.oauthAuthFlow.waitCancel()
|
||||
}
|
||||
|
||||
waitCTX, cancel := context.WithCancel(ctx)
|
||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
||||
defer cancel()
|
||||
|
||||
s.mutex.Lock()
|
||||
|
||||
@@ -207,6 +207,8 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||
// Create a backend session to mirror the client's session request.
|
||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
@@ -214,28 +216,10 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
<-session.Context().Done()
|
||||
|
||||
if err := serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- serverSession.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
log.Debugf("shell session: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
if err := session.Exit(0); err != nil {
|
||||
log.Debugf("session exit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleExecution executes an SSH command or shell with privilege validation
|
||||
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
@@ -23,7 +23,7 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
@@ -51,12 +51,13 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
@@ -65,28 +66,28 @@ func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheck
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,17 +15,17 @@ import (
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -100,52 +99,40 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||
return isUtilLinux
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su - for privilege switching.
|
||||
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
args := []string{"-"}
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username)
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
|
||||
command := session.RawCommand()
|
||||
if command != "" {
|
||||
args = append(args, "-c", command)
|
||||
}
|
||||
|
||||
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell}
|
||||
return []string{shell, "-l"}
|
||||
}
|
||||
return []string{shell, "-c", cmdString}
|
||||
}
|
||||
|
||||
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
return cmd
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
@@ -167,7 +154,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
@@ -257,6 +244,11 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
@@ -276,7 +268,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,20 +296,17 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
} else {
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Close PTY to unblock io.Copy goroutines
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close after completion: %v", err)
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,32 +20,32 @@ import (
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(logger, username, domain)
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
logger.Debugf("close user token: %v", err)
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken window
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
@@ -267,16 +267,22 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -288,6 +294,11 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
@@ -297,25 +308,25 @@ type PtyExecutionRequest struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
logger.Debugf("close user token: %v", err)
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
@@ -337,8 +348,8 @@ func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
@@ -360,8 +371,10 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,7 +389,21 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
@@ -388,14 +415,14 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: session.RawCommand(),
|
||||
Command: command,
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -26,67 +23,25 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||
// propagate exit codes.
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
runTestExecutor()
|
||||
return
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// runTestExecutor emulates the netbird executor for tests.
|
||||
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||
func runTestExecutor() {
|
||||
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||
|
||||
shell := "/bin/sh"
|
||||
var command string
|
||||
for i := 3; i < len(os.Args); i++ {
|
||||
switch os.Args[i] {
|
||||
case "--shell":
|
||||
if i+1 < len(os.Args) {
|
||||
shell = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--cmd":
|
||||
if i+1 < len(os.Args) {
|
||||
command = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if command == "" {
|
||||
cmd = exec.Command(shell)
|
||||
} else {
|
||||
cmd = exec.Command(shell, "-c", command)
|
||||
}
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
@@ -450,171 +405,6 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||
// This ensures our implementation matches OpenSSH behavior for:
|
||||
// - ssh host command (no PTY - default when no TTY)
|
||||
// - ssh -T host command (explicit no PTY)
|
||||
// - ssh -t host command (force PTY)
|
||||
// - ssh -T host (no PTY shell - our implementation)
|
||||
func TestSSHPtyModes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
baseArgs := []string{
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "BatchMode=yes",
|
||||
}
|
||||
|
||||
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "no_pty_default")
|
||||
})
|
||||
|
||||
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "explicit_no_pty")
|
||||
})
|
||||
|
||||
t.Run("command_force_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "force_pty")
|
||||
})
|
||||
|
||||
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||
assert.NoError(t, err, "write echo command")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
assert.NoError(t, err, "write exit command")
|
||||
}()
|
||||
|
||||
output, _ := io.ReadAll(stdout)
|
||||
err = cmd.Wait()
|
||||
|
||||
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "Command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "PTY command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||
})
|
||||
|
||||
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||
})
|
||||
|
||||
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||
assert.Contains(t, string(output), "stdout_msg")
|
||||
assert.Contains(t, string(output), "stderr_msg")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -36,35 +35,11 @@ type ExecutorConfig struct {
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
@@ -108,7 +83,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
||||
break
|
||||
}
|
||||
}
|
||||
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
@@ -231,22 +206,17 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||
} else {
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
if config.Command == "" {
|
||||
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||
} else {
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
}
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
|
||||
@@ -28,45 +28,22 @@ const (
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct {
|
||||
logger *log.Entry
|
||||
}
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||
|
||||
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||
pd := &PrivilegeDropper{}
|
||||
for _, opt := range opts {
|
||||
opt(pd)
|
||||
}
|
||||
return pd
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the PrivilegeDropper
|
||||
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||
return func(pd *PrivilegeDropper) {
|
||||
pd.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// log returns the logger, falling back to standard logger if none set
|
||||
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||
if pd.logger != nil {
|
||||
return pd.logger
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -79,6 +56,7 @@ const (
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
@@ -102,7 +80,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
@@ -202,10 +180,10 @@ func newLsaString(s string) lsaString {
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||
pd := NewPrivilegeDropper()
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
@@ -219,12 +197,12 @@ func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.H
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
@@ -332,21 +310,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(logger, username, domain)
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(logger, username)
|
||||
return prepareLocalS4ULogon(username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
@@ -379,8 +357,8 @@ func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.P
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
@@ -428,11 +406,11 @@ func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, u
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
@@ -463,7 +441,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -471,7 +449,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
logger.Debugf("created S4U %s token for user %s",
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
@@ -519,8 +497,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
@@ -529,12 +507,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -548,7 +526,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
pd.log().Debugf("close impersonation token: %v", err)
|
||||
log.Debugf("close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -586,7 +564,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
|
||||
server.UpdateSSHAuth(authConfig)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() { require.NoError(t, server.Stop()) }()
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -271,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
||||
func (s *Server) isPortForwardingEnabled() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
|
||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
sessions = append(sessions, info)
|
||||
}
|
||||
|
||||
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
||||
for key, connState := range s.connections {
|
||||
remoteAddr := string(key)
|
||||
if reportedAddrs[remoteAddr] {
|
||||
|
||||
@@ -483,11 +483,12 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -495,26 +496,36 @@ func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
expectAllowed bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "shell_with_local_forwarding_enabled",
|
||||
name: "session_allowed_with_local_forwarding",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "shell_with_remote_forwarding_enabled",
|
||||
name: "session_allowed_with_remote_forwarding",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
||||
},
|
||||
{
|
||||
name: "shell_with_both_forwarding_enabled",
|
||||
name: "session_allowed_with_both",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
expectAllowed: true,
|
||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
||||
},
|
||||
{
|
||||
name: "shell_with_forwarding_disabled",
|
||||
name: "session_denied_without_forwarding",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
expectAllowed: false,
|
||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -534,6 +545,7 @@ func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
// Connect to the server without requesting PTY or command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -545,10 +557,20 @@ func TestServer_NonPtyShellSession(t *testing.T) {
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||
// Should always succeed regardless of port forwarding settings
|
||||
_, err = client.ExecuteCommand(ctx, "")
|
||||
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||
// Execute a command without PTY - this simulates ssh -T with no command
|
||||
// The server should either allow it (port forwarding enabled) or reject it
|
||||
output, err := client.ExecuteCommand(ctx, "")
|
||||
if tt.expectAllowed {
|
||||
// When allowed, the session stays open until cancelled
|
||||
// ExecuteCommand with empty command should return without error
|
||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
||||
assert.NotContains(t, output, "port forwarding is disabled",
|
||||
"Output should not contain port forwarding disabled message")
|
||||
} else if err != nil {
|
||||
// When denied, we expect an error message about port forwarding being disabled
|
||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
||||
"Should get port forwarding disabled message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,14 +405,12 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-c", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
args = server.getShellCommandArgs("/bin/sh", "")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Len(t, args, 1)
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -62,12 +62,54 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||
} else {
|
||||
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
// ssh -T (or ssh -N) - no PTY, no command
|
||||
s.handleNonInteractiveSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
||||
s.updateSessionType(session, cmdNonInteractive)
|
||||
|
||||
if !s.isPortForwardingEnabled() {
|
||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
<-session.Context().Done()
|
||||
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, state := range s.sessions {
|
||||
if state.session == session {
|
||||
state.sessionType = sessionType
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePtyLogin is not supported on JS/WASM
|
||||
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
|
||||
@@ -8,18 +8,19 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartTestServer starts the SSH server and returns the address it's listening on.
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Use port 0 to let the OS assign a free port
|
||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the actual listening address from the server
|
||||
actualAddr := server.Addr()
|
||||
if actualAddr == nil {
|
||||
errChan <- fmt.Errorf("server started but no listener address available")
|
||||
|
||||
@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||
// Returns the command and a cleanup function (no-op on Unix).
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
|
||||
@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(logger, session, localUser)
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
@@ -113,14 +113,15 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper(WithLogger(logger))
|
||||
dropper := NewPrivilegeDropper()
|
||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -129,7 +130,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
cleanup := func() {
|
||||
if token != 0 {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
logger.Debugf("close primary token: %v", err)
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ var (
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfi
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: session.Context(),
|
||||
Context: ctx,
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
|
||||
@@ -63,8 +63,6 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleNetworksClick()
|
||||
case <-h.client.mNotifications.ClickedCh:
|
||||
h.handleNotificationsClick()
|
||||
case <-systray.TrayOpenedCh:
|
||||
h.client.updateExitNodes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +341,7 @@ func (s *serviceClient) updateExitNodes() {
|
||||
log.Errorf("get client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
exitNodes, err := s.getExitNodes(conn)
|
||||
if err != nil {
|
||||
log.Errorf("get exit nodes: %v", err)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -31,7 +31,7 @@ require (
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/awnumar/memguard v0.23.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
|
||||
4
go.sum
4
go.sum
@@ -13,8 +13,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
|
||||
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI=
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||
|
||||
@@ -232,9 +232,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,12 +30,6 @@ type Manager interface {
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInvite(ctx context.Context, token, password string) error
|
||||
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
|
||||
|
||||
@@ -199,11 +199,6 @@ const (
|
||||
|
||||
UserPasswordChanged Activity = 103
|
||||
|
||||
UserInviteLinkCreated Activity = 104
|
||||
UserInviteLinkAccepted Activity = 105
|
||||
UserInviteLinkRegenerated Activity = 106
|
||||
UserInviteLinkDeleted Activity = 107
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -332,11 +327,6 @@ var activityMap = map[Activity]Code{
|
||||
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
|
||||
|
||||
UserPasswordChanged: {"User password changed", "user.password.change"},
|
||||
|
||||
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
|
||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -68,13 +68,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if err := bypass.AddBypassPath("/api/setup"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
// Public invite endpoints (tokens start with nbi_)
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
@@ -139,8 +132,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
@@ -154,7 +145,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
|
||||
@@ -28,15 +28,6 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
// This endpoint is unauthenticated.
|
||||
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -74,29 +65,3 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
Email: userData.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.InstanceVersionInfo{
|
||||
ManagementCurrentVersion: versionInfo.CurrentVersion,
|
||||
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
|
||||
}
|
||||
|
||||
if versionInfo.DashboardVersion != "" {
|
||||
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
|
||||
}
|
||||
|
||||
if versionInfo.ManagementVersion != "" {
|
||||
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
@@ -67,18 +66,6 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
if m.getVersionInfoFn != nil {
|
||||
return m.getVersionInfoFn(ctx)
|
||||
}
|
||||
return &nbinstance.VersionInfo{
|
||||
CurrentVersion: "0.34.0",
|
||||
DashboardVersion: "2.0.0",
|
||||
ManagementVersion: "0.35.0",
|
||||
ManagementUpdateAvailable: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
@@ -292,44 +279,3 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.InstanceVersionInfo
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
|
||||
assert.NotNil(t, response.DashboardAvailableVersion)
|
||||
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
|
||||
assert.NotNil(t, response.ManagementAvailableVersion)
|
||||
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
|
||||
assert.True(t, response.ManagementUpdateAvailable)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Error(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
return nil, errors.New("failed to fetch versions")
|
||||
},
|
||||
}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
|
||||
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10, // 10 attempts per minute per IP
|
||||
Burst: 5, // Allow burst of 5 requests
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
LimiterTTL: 30 * time.Minute,
|
||||
})
|
||||
|
||||
// toUserInviteResponse converts a UserInvite to an API response.
|
||||
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
|
||||
autoGroups := invite.UserInfo.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
var inviteLink *string
|
||||
if invite.InviteToken != "" {
|
||||
inviteLink = &invite.InviteToken
|
||||
}
|
||||
return api.UserInvite{
|
||||
Id: invite.UserInfo.ID,
|
||||
Email: invite.UserInfo.Email,
|
||||
Name: invite.UserInfo.Name,
|
||||
Role: invite.UserInfo.Role,
|
||||
AutoGroups: autoGroups,
|
||||
ExpiresAt: invite.InviteExpiresAt.UTC(),
|
||||
CreatedAt: invite.InviteCreatedAt.UTC(),
|
||||
Expired: time.Now().After(invite.InviteExpiresAt),
|
||||
InviteToken: inviteLink,
|
||||
}
|
||||
}
|
||||
|
||||
// invitesHandler handles user invite operations
|
||||
type invitesHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Create a subrouter for public invite endpoints with rate limiting middleware
|
||||
publicRouter := router.PathPrefix("/users/invites").Subrouter()
|
||||
publicRouter.Use(publicInviteRateLimiter.Middleware)
|
||||
|
||||
// Public endpoints (no auth required, protected by token and rate limited)
|
||||
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
|
||||
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.UserInvite, 0, len(invites))
|
||||
for _, invite := range invites {
|
||||
resp = append(resp, toUserInviteResponse(invite))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
invite := &types.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
result.InviteCreatedAt = time.Now().UTC()
|
||||
resp := toUserInviteResponse(result)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// getInviteInfo handles GET /api/users/invites/{token}
|
||||
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := info.ExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
ExpiresAt: expiresAt,
|
||||
Valid: info.Valid,
|
||||
InvitedBy: info.InvitedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// acceptInvite handles POST /api/users/invites/{token}/accept
|
||||
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteAcceptRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteRegenerateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Allow empty body (io.EOF) - expiresIn is optional
|
||||
if !errors.Is(err, io.EOF) {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := result.InviteExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
|
||||
InviteToken: result.InviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
@@ -1,642 +0,0 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testInviteID = "test-invite-id"
|
||||
testInviteToken = "nbi_testtoken123456789012345678"
|
||||
testEmail = "invite@example.com"
|
||||
testName = "Test User"
|
||||
)
|
||||
|
||||
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
|
||||
return &invitesHandler{
|
||||
accountManager: am,
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInvites(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
testInvites := []*types.UserInvite{
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-1",
|
||||
Email: "user1@example.com",
|
||||
Name: "User One",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
},
|
||||
InviteExpiresAt: now.Add(24 * time.Hour),
|
||||
InviteCreatedAt: now,
|
||||
},
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-2",
|
||||
Email: "user2@example.com",
|
||||
Name: "User Two",
|
||||
Role: "admin",
|
||||
AutoGroups: nil,
|
||||
},
|
||||
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
|
||||
InviteCreatedAt: now.Add(-48 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return testInvites, nil
|
||||
},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return []*types.UserInvite{}, nil
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
ListUserInvitesFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp []api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, tc.expectedCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful create with custom expiration",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 3600, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: []string{},
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user already exists",
|
||||
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite already exists",
|
||||
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
CreateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testInviteID, resp.Id)
|
||||
assert.NotNil(t, resp.InviteToken)
|
||||
assert.NotEmpty(t, *resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInviteInfo(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
}{
|
||||
{
|
||||
name: "successful get valid invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
Valid: true,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful get expired invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(-24 * time.Hour),
|
||||
Valid: false,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
GetUserInviteInfoFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.getInviteInfo(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteInfo
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testEmail, resp.Email)
|
||||
assert.Equal(t, testName, resp.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token, password string) error
|
||||
}{
|
||||
{
|
||||
name: "successful accept",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite expired",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "invite has expired")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
token: testInviteToken,
|
||||
requestBody: `{invalid}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"Short1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing digit",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoDigitPass!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing uppercase",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"nouppercase1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing special character",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoSpecial123"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
AcceptUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.acceptInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteAcceptResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Success)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegenerateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful regenerate with empty body",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 0, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful regenerate with custom expiration",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{"expires_in":7200}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 7200, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON should return error",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
RegenerateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
var body io.Reader
|
||||
if tc.requestBody != "" {
|
||||
body = bytes.NewBufferString(tc.requestBody)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteRegenerateResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}{
|
||||
{
|
||||
name: "successful delete",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
DeleteUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,14 +2,10 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
@@ -148,25 +144,3 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that rate limits requests by client IP.
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAPIRateLimiter_Allow(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// First two requests should be allowed (burst)
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
|
||||
// Third request should be denied (exceeded burst)
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Different key should be allowed
|
||||
assert.True(t, rl.Allow("different-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Create a simple handler that returns 200 OK
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Wrap with rate limiter middleware
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// First two requests should pass (burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// Request from first IP
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
assert.Equal(t, http.StatusOK, rr1.Code)
|
||||
|
||||
// Second request from first IP should be rate limited
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.1:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
|
||||
|
||||
// Request from different IP should be allowed
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.2:12345"
|
||||
rr3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr3, req3)
|
||||
assert.Equal(t, http.StatusOK, rr3.Code)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "remote addr with port",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "remote addr without port",
|
||||
remoteAddr: "192.168.1.1",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port",
|
||||
remoteAddr: "[::1]:12345",
|
||||
expected: "::1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 without port",
|
||||
remoteAddr: "::1",
|
||||
expected: "::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
assert.Equal(t, tc.expected, getClientIP(req))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Use up the burst
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Reset the limiter
|
||||
rl.Reset("test-key")
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
@@ -20,7 +20,7 @@ const (
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email groups"
|
||||
defaultScopes = "openid profile email"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
|
||||
@@ -2,54 +2,18 @@ package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// Version endpoints
|
||||
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
|
||||
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
|
||||
|
||||
// Cache TTL for version information
|
||||
versionCacheTTL = 60 * time.Minute
|
||||
|
||||
// HTTP client timeout
|
||||
httpTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// VersionInfo contains version information for NetBird components
|
||||
type VersionInfo struct {
|
||||
// CurrentVersion is the running management server version
|
||||
CurrentVersion string
|
||||
// DashboardVersion is the latest available dashboard version from GitHub
|
||||
DashboardVersion string
|
||||
// ManagementVersion is the latest available management version from GitHub
|
||||
ManagementVersion string
|
||||
// ManagementUpdateAvailable indicates if a newer management version is available
|
||||
ManagementUpdateAvailable bool
|
||||
}
|
||||
|
||||
// githubRelease represents a GitHub release response
|
||||
type githubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
}
|
||||
|
||||
// Manager handles instance-level operations like initial setup.
|
||||
type Manager interface {
|
||||
// IsSetupRequired checks if instance setup is required.
|
||||
@@ -59,9 +23,6 @@ type Manager interface {
|
||||
// CreateOwnerUser creates the initial owner user in the embedded IDP.
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of Manager.
|
||||
@@ -71,12 +32,6 @@ type DefaultManager struct {
|
||||
|
||||
setupRequired bool
|
||||
setupMu sync.RWMutex
|
||||
|
||||
// Version caching
|
||||
httpClient *http.Client
|
||||
versionMu sync.RWMutex
|
||||
cachedVersions *VersionInfo
|
||||
lastVersionFetch time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new instance manager.
|
||||
@@ -88,9 +43,6 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
||||
store: store,
|
||||
embeddedIdpManager: embeddedIdp,
|
||||
setupRequired: false,
|
||||
httpClient: &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
if embeddedIdp != nil {
|
||||
@@ -182,130 +134,3 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.RLock()
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.RUnlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.RUnlock()
|
||||
|
||||
return m.fetchVersionInfo(ctx)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.Unlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.Unlock()
|
||||
|
||||
info := &VersionInfo{
|
||||
CurrentVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
// Fetch management version from pkgs.netbird.io (plain text)
|
||||
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
|
||||
} else {
|
||||
info.ManagementVersion = mgmtVersion
|
||||
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
|
||||
}
|
||||
|
||||
// Fetch dashboard version from GitHub
|
||||
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
|
||||
} else {
|
||||
info.DashboardVersion = dashVersion
|
||||
}
|
||||
|
||||
// Update cache
|
||||
m.versionMu.Lock()
|
||||
m.cachedVersions = info
|
||||
m.lastVersionFetch = time.Now()
|
||||
m.versionMu.Unlock()
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// isNewerVersion returns true if latestVersion is greater than currentVersion
|
||||
func isNewerVersion(currentVersion, latestVersion string) bool {
|
||||
current, err := goversion.NewVersion(currentVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
latest, err := goversion.NewVersion(latestVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return latest.GreaterThan(current)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release githubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
// Remove 'v' prefix if present
|
||||
tag := release.TagName
|
||||
if len(tag) > 0 && tag[0] == 'v' {
|
||||
tag = tag[1:]
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockRoundTripper implements http.RoundTripper for testing
|
||||
type mockRoundTripper struct {
|
||||
callCount atomic.Int32
|
||||
managementVersion string
|
||||
dashboardVersion string
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
var body string
|
||||
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
|
||||
// Plain text response for management version
|
||||
body = m.managementVersion
|
||||
} else if strings.Contains(req.URL.String(), "github.com") {
|
||||
// JSON response for dashboard version
|
||||
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
|
||||
body = string(jsonResp)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
info, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// CurrentVersion should always be set
|
||||
assert.NotEmpty(t, info.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info.ManagementVersion)
|
||||
assert.Equal(t, "2.10.0", info.DashboardVersion)
|
||||
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First call
|
||||
info1, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, info1.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info1.ManagementVersion)
|
||||
|
||||
initialCallCount := mockTransport.callCount.Load()
|
||||
|
||||
// Second call should use cache (no additional HTTP calls)
|
||||
info2, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
|
||||
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
|
||||
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
|
||||
|
||||
// Verify no additional HTTP calls were made (cache was used)
|
||||
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tagName string
|
||||
expected string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "tag with v prefix",
|
||||
tagName: "v1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag without v prefix",
|
||||
tagName: "1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag with prerelease",
|
||||
tagName: "v2.0.0-beta.1",
|
||||
expected: "2.0.0-beta.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
|
||||
if tc.shouldError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: `{"message": "Not Found"}`,
|
||||
},
|
||||
{
|
||||
name: "rate limited",
|
||||
statusCode: http.StatusForbidden,
|
||||
body: `{"message": "API rate limit exceeded"}`,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: `{"message": "Internal Server Error"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.statusCode)
|
||||
_, _ = w.Write([]byte(tc.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{invalid json}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := m.fetchGitHubRelease(ctx, server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsNewerVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentVersion string
|
||||
latestVersion string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "latest is newer - minor version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.65.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - patch version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - major version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "1.0.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "versions are equal",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - minor version",
|
||||
currentVersion: "0.65.0",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - patch version",
|
||||
currentVersion: "0.64.2",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "development version",
|
||||
currentVersion: "development",
|
||||
latestVersion: "0.65.0",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid latest version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "invalid",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,12 +139,6 @@ type MockAccountManager struct {
|
||||
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
|
||||
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
@@ -719,48 +713,6 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.CreateUserInviteFunc != nil {
|
||||
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
|
||||
if am.AcceptUserInviteFunc != nil {
|
||||
return am.AcceptUserInviteFunc(ctx, token, password)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.RegenerateUserInviteFunc != nil {
|
||||
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
if am.GetUserInviteInfoFunc != nil {
|
||||
return am.GetUserInviteInfoFunc(ctx, token)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
if am.ListUserInvitesFunc != nil {
|
||||
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
if am.DeleteUserInviteFunc != nil {
|
||||
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||
|
||||
@@ -150,8 +150,26 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
now := time.Now().UTC()
|
||||
newStatus.LastSeen = now
|
||||
newStatus.Connected = connected
|
||||
|
||||
if oldStatus.Connected == connected {
|
||||
log.WithContext(ctx).Warnf("peer %s status race: already connected=%t (lastSeen=%s, now=%s, ephemeral=%t)",
|
||||
peer.ID,
|
||||
connected,
|
||||
oldStatus.LastSeen.Format(time.RFC3339Nano),
|
||||
now.Format(time.RFC3339Nano),
|
||||
peer.Ephemeral)
|
||||
}
|
||||
|
||||
if !connected && oldStatus.Connected && peer.Ephemeral {
|
||||
log.WithContext(ctx).Tracef("ephemeral peer %s disconnecting (lastSeen=%s, now=%s)",
|
||||
peer.ID,
|
||||
oldStatus.LastSeen.Format(time.RFC3339Nano),
|
||||
now.Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
newStatus.LoginExpired = false
|
||||
|
||||
@@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -815,130 +815,6 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// SaveUserInvite saves a user invite to the database
|
||||
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
|
||||
inviteCopy := invite.Copy()
|
||||
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt invite: %w", err)
|
||||
}
|
||||
|
||||
result := s.db.Save(inviteCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user invite to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserInviteByID retrieves a user invite by its ID and account ID
|
||||
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invite types.UserInviteRecord
|
||||
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
||||
}
|
||||
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
|
||||
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invite types.UserInviteRecord
|
||||
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
||||
}
|
||||
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
// GetUserInviteByEmail retrieves a user invite by account ID and email.
|
||||
// Since email is encrypted with random IVs, we fetch all invites for the account
|
||||
// and compare emails in memory after decryption.
|
||||
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invites []*types.UserInviteRecord
|
||||
result := tx.Find(&invites, "account_id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
if strings.EqualFold(invite.Email, email) {
|
||||
return invite, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found for email")
|
||||
}
|
||||
|
||||
// GetAccountUserInvites retrieves all user invites for an account
|
||||
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invites []*types.UserInviteRecord
|
||||
result := tx.Find(&invites, "account_id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
// DeleteUserInvite deletes a user invite by its ID
|
||||
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
|
||||
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete user invite from store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
@@ -4393,9 +4269,6 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
|
||||
Take(&userID, GetKeyQueryCondition(s), peerKey)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "peer not found: index lookup failed")
|
||||
}
|
||||
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestSqlStore_SaveUserInvite(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-1",
|
||||
AccountID: "account-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1", "group-2"},
|
||||
HashedToken: "hashed-token-123",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invite was saved
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
assert.Equal(t, invite.Name, retrieved.Name)
|
||||
assert.Equal(t, invite.Role, retrieved.Role)
|
||||
assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups)
|
||||
assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-update",
|
||||
AccountID: "account-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-123",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the invite with a new token
|
||||
invite.HashedToken = "new-hashed-token"
|
||||
invite.ExpiresAt = time.Now().Add(24 * time.Hour)
|
||||
|
||||
err = store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the update
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByID(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-id",
|
||||
AccountID: "account-1",
|
||||
Email: "getbyid@example.com",
|
||||
Name: "Get By ID User",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-get",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID - success
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
|
||||
// Get by ID - wrong account
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Get by ID - not found
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-token",
|
||||
AccountID: "account-1",
|
||||
Email: "getbytoken@example.com",
|
||||
Name: "Get By Token User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "unique-hashed-token-456",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by hashed token - success
|
||||
retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
|
||||
// Get by hashed token - not found
|
||||
_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-email",
|
||||
AccountID: "account-email-test",
|
||||
Email: "unique-email@example.com",
|
||||
Name: "Get By Email User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-email",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by email - success
|
||||
retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
// Get by email - case insensitive
|
||||
retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
// Get by email - wrong account
|
||||
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Get by email - not found
|
||||
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "account-list-invites"
|
||||
|
||||
invites := []*types.UserInviteRecord{
|
||||
{
|
||||
ID: "invite-list-1",
|
||||
AccountID: accountID,
|
||||
Email: "user1@example.com",
|
||||
Name: "User One",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-list-1",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
{
|
||||
ID: "invite-list-2",
|
||||
AccountID: accountID,
|
||||
Email: "user2@example.com",
|
||||
Name: "User Two",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{"group-2"},
|
||||
HashedToken: "hashed-token-list-2",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
{
|
||||
ID: "invite-list-3",
|
||||
AccountID: "different-account",
|
||||
Email: "user3@example.com",
|
||||
Name: "User Three",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-list-3",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get all invites for the account
|
||||
retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, retrieved, 2)
|
||||
|
||||
// Verify the invites belong to the correct account
|
||||
for _, invite := range retrieved {
|
||||
assert.Equal(t, accountID, invite.AccountID)
|
||||
}
|
||||
|
||||
// Get invites for account with no invites
|
||||
retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, retrieved, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUserInvite(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-delete",
|
||||
AccountID: "account-delete-test",
|
||||
Email: "delete@example.com",
|
||||
Name: "Delete User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-delete",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify invite exists
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the invite
|
||||
err = store.DeleteUserInvite(ctx, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify invite is deleted
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-encrypted",
|
||||
AccountID: "account-encrypted",
|
||||
Email: "sensitive-email@example.com",
|
||||
Name: "Sensitive Name",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-encrypted",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify decryption works
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
|
||||
assert.Equal(t, "Sensitive Name", retrieved.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Deleting a non-existent invite should not return an error
|
||||
err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
email := "shared-email@example.com"
|
||||
|
||||
// Create invite in first account
|
||||
invite1 := &types.UserInviteRecord{
|
||||
ID: "invite-account1",
|
||||
AccountID: "account-1",
|
||||
Email: email,
|
||||
Name: "User Account 1",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-account1",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-1",
|
||||
}
|
||||
|
||||
// Create invite in second account with same email
|
||||
invite2 := &types.UserInviteRecord{
|
||||
ID: "invite-account2",
|
||||
AccountID: "account-2",
|
||||
Email: email,
|
||||
Name: "User Account 2",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-account2",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-2",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveUserInvite(ctx, invite2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify each account gets the correct invite by email
|
||||
retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "invite-account1", retrieved1.ID)
|
||||
assert.Equal(t, "User Account 1", retrieved1.Name)
|
||||
|
||||
retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "invite-account2", retrieved2.ID)
|
||||
assert.Equal(t, "User Account 2", retrieved2.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-locking",
|
||||
AccountID: "account-locking",
|
||||
Email: "locking@example.com",
|
||||
Name: "Locking Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-locking",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with different locking strengths
|
||||
lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}
|
||||
|
||||
for _, strength := range lockStrengths {
|
||||
retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, invites, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil AutoGroups
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-nil-autogroups",
|
||||
AccountID: "account-autogroups",
|
||||
Email: "nilgroups@example.com",
|
||||
Name: "Nil Groups User",
|
||||
Role: "user",
|
||||
AutoGroups: nil,
|
||||
HashedToken: "hashed-token-nil",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
// Should return empty slice or nil, both are acceptable
|
||||
assert.Empty(t, retrieved.AutoGroups)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC().Truncate(time.Millisecond)
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-timestamp",
|
||||
AccountID: "account-timestamp",
|
||||
Email: "timestamp@example.com",
|
||||
Name: "Timestamp User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-timestamp",
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: now,
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps are preserved (within reasonable precision)
|
||||
assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
|
||||
assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
|
||||
})
|
||||
}
|
||||
@@ -92,13 +92,6 @@ type Store interface {
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
|
||||
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
|
||||
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
|
||||
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
|
||||
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
|
||||
DeleteUserInvite(ctx context.Context, inviteID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// InviteTokenPrefix is the prefix for invite tokens
|
||||
InviteTokenPrefix = "nbi_"
|
||||
// InviteTokenSecretLength is the length of the random secret part
|
||||
InviteTokenSecretLength = 30
|
||||
// InviteTokenChecksumLength is the length of the encoded checksum
|
||||
InviteTokenChecksumLength = 6
|
||||
// InviteTokenLength is the total length of the token (4 + 30 + 6 = 40)
|
||||
InviteTokenLength = 40
|
||||
// DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours)
|
||||
DefaultInviteExpirationSeconds = 259200
|
||||
// MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour)
|
||||
MinInviteExpirationSeconds = 3600
|
||||
)
|
||||
|
||||
// UserInviteRecord represents an invitation for a user to set up their account (database model)
|
||||
type UserInviteRecord struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index;not null"`
|
||||
Email string `gorm:"index;not null"`
|
||||
Name string `gorm:"not null"`
|
||||
Role string `gorm:"not null"`
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded)
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
CreatedBy string `gorm:"not null"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (UserInviteRecord) TableName() string {
|
||||
return "user_invites"
|
||||
}
|
||||
|
||||
// GenerateInviteToken creates a new invite token with the format: nbi_<secret><checksum>
|
||||
// Returns the hashed token (for storage) and the plain token (to give to the user)
|
||||
func GenerateInviteToken() (hashedToken string, plainToken string, err error) {
|
||||
secret, err := b.Random(InviteTokenSecretLength)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random secret: %w", err)
|
||||
}
|
||||
|
||||
checksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
encodedChecksum := base62.Encode(checksum)
|
||||
// Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode)
|
||||
paddedChecksum := encodedChecksum
|
||||
if len(paddedChecksum) < InviteTokenChecksumLength {
|
||||
paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum
|
||||
}
|
||||
|
||||
plainToken = InviteTokenPrefix + secret + paddedChecksum
|
||||
hash := sha256.Sum256([]byte(plainToken))
|
||||
hashedToken = b64.StdEncoding.EncodeToString(hash[:])
|
||||
|
||||
return hashedToken, plainToken, nil
|
||||
}
|
||||
|
||||
// HashInviteToken creates a SHA-256 hash of the token (base64 encoded)
|
||||
func HashInviteToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return b64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// ValidateInviteToken validates the token format and checksum.
|
||||
// Returns an error if the token is invalid.
|
||||
func ValidateInviteToken(token string) error {
|
||||
if len(token) != InviteTokenLength {
|
||||
return fmt.Errorf("invalid token length")
|
||||
}
|
||||
|
||||
prefix := token[:len(InviteTokenPrefix)]
|
||||
if prefix != InviteTokenPrefix {
|
||||
return fmt.Errorf("invalid token prefix")
|
||||
}
|
||||
|
||||
secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
|
||||
encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return fmt.Errorf("checksum does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the invite has expired
|
||||
func (i *UserInviteRecord) IsExpired() bool {
|
||||
return time.Now().After(i.ExpiresAt)
|
||||
}
|
||||
|
||||
// UserInvite contains the result of creating or regenerating an invite
|
||||
type UserInvite struct {
|
||||
UserInfo *UserInfo
|
||||
InviteToken string
|
||||
InviteExpiresAt time.Time
|
||||
InviteCreatedAt time.Time
|
||||
}
|
||||
|
||||
// UserInviteInfo contains public information about an invite (for unauthenticated endpoint)
|
||||
type UserInviteInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Valid bool `json:"valid"`
|
||||
InvitedBy string `json:"invited_by"`
|
||||
}
|
||||
|
||||
// NewInviteID generates a new invite ID using xid
|
||||
func NewInviteID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place.
|
||||
func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if i.Email != "" {
|
||||
i.Email, err = enc.Encrypt(i.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if i.Name != "" {
|
||||
i.Name, err = enc.Encrypt(i.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place.
|
||||
func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if i.Email != "" {
|
||||
i.Email, err = enc.Decrypt(i.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if i.Name != "" {
|
||||
i.Name, err = enc.Decrypt(i.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the UserInviteRecord
|
||||
func (i *UserInviteRecord) Copy() *UserInviteRecord {
|
||||
autoGroups := make([]string, len(i.AutoGroups))
|
||||
copy(autoGroups, i.AutoGroups)
|
||||
|
||||
return &UserInviteRecord{
|
||||
ID: i.ID,
|
||||
AccountID: i.AccountID,
|
||||
Email: i.Email,
|
||||
Name: i.Name,
|
||||
Role: i.Role,
|
||||
AutoGroups: autoGroups,
|
||||
HashedToken: i.HashedToken,
|
||||
ExpiresAt: i.ExpiresAt,
|
||||
CreatedAt: i.CreatedAt,
|
||||
CreatedBy: i.CreatedBy,
|
||||
}
|
||||
}
|
||||
@@ -1,355 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUserInviteRecord_TableName(t *testing.T) {
|
||||
invite := UserInviteRecord{}
|
||||
assert.Equal(t, "user_invites", invite.TableName())
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Success(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hashedToken)
|
||||
assert.NotEmpty(t, plainToken)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Length(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, plainToken, InviteTokenLength)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Prefix(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Hashing(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedHash := sha256.Sum256([]byte(plainToken))
|
||||
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
assert.Equal(t, expectedHashedToken, hashedToken)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Checksum(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract parts
|
||||
secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
|
||||
checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]
|
||||
|
||||
// Verify checksum
|
||||
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
actualChecksum, err := base62.Decode(checksumStr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedChecksum, actualChecksum)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
|
||||
tokens := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, tokens[plainToken], "Token should be unique")
|
||||
tokens[plainToken] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashInviteToken(t *testing.T) {
|
||||
token := "nbi_testtoken123456789012345678901234"
|
||||
hashedToken := HashInviteToken(token)
|
||||
|
||||
expectedHash := sha256.Sum256([]byte(token))
|
||||
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
assert.Equal(t, expectedHashedToken, hashedToken)
|
||||
}
|
||||
|
||||
func TestHashInviteToken_Consistency(t *testing.T) {
|
||||
token := "nbi_testtoken123456789012345678901234"
|
||||
hash1 := HashInviteToken(token)
|
||||
hash2 := HashInviteToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestHashInviteToken_DifferentTokens(t *testing.T) {
|
||||
token1 := "nbi_testtoken123456789012345678901234"
|
||||
token2 := "nbi_testtoken123456789012345678901235"
|
||||
hash1 := HashInviteToken(token1)
|
||||
hash2 := HashInviteToken(token2)
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_Success(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateInviteToken(plainToken)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidLength(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"too short", "nbi_abc"},
|
||||
{"too long", "nbi_" + strings.Repeat("a", 50)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateInviteToken(tc.token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid token length")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
|
||||
// Create a token with wrong prefix but correct length
|
||||
token := "xyz_" + strings.Repeat("a", 30) + "000000"
|
||||
err := ValidateInviteToken(token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid token prefix")
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
|
||||
// Create a token with correct format but invalid checksum
|
||||
token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
|
||||
err := ValidateInviteToken(token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "checksum")
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify one character in the secret part
|
||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
||||
err = ValidateInviteToken(modifiedToken)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_IsExpired(t *testing.T) {
|
||||
t.Run("not expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
assert.False(t, invite.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
assert.True(t, invite.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("just expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(-time.Second),
|
||||
}
|
||||
assert.True(t, invite.IsExpired())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewInviteID(t *testing.T) {
|
||||
id := NewInviteID()
|
||||
assert.NotEmpty(t, id)
|
||||
assert.Len(t, id, 20) // xid generates 20 character IDs
|
||||
}
|
||||
|
||||
func TestNewInviteID_Uniqueness(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := NewInviteID()
|
||||
assert.False(t, ids[id], "ID should be unique")
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("encrypt and decrypt", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
err := invite.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify encrypted values are different from original
|
||||
assert.NotEqual(t, "test@example.com", invite.Email)
|
||||
assert.NotEqual(t, "Test User", invite.Name)
|
||||
|
||||
// Decrypt
|
||||
err = invite.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify decrypted values match original
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
})
|
||||
|
||||
t.Run("encrypt empty fields", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := invite.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", invite.Email)
|
||||
assert.Equal(t, "", invite.Name)
|
||||
|
||||
err = invite.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", invite.Email)
|
||||
assert.Equal(t, "", invite.Name)
|
||||
})
|
||||
|
||||
t.Run("nil encryptor", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := invite.EncryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
|
||||
err = invite.DecryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy(t *testing.T) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
HashedToken: "hashed-token",
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: now,
|
||||
CreatedBy: "creator-id",
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
// Verify all fields are copied
|
||||
assert.Equal(t, original.ID, copied.ID)
|
||||
assert.Equal(t, original.AccountID, copied.AccountID)
|
||||
assert.Equal(t, original.Email, copied.Email)
|
||||
assert.Equal(t, original.Name, copied.Name)
|
||||
assert.Equal(t, original.Role, copied.Role)
|
||||
assert.Equal(t, original.AutoGroups, copied.AutoGroups)
|
||||
assert.Equal(t, original.HashedToken, copied.HashedToken)
|
||||
assert.Equal(t, original.ExpiresAt, copied.ExpiresAt)
|
||||
assert.Equal(t, original.CreatedAt, copied.CreatedAt)
|
||||
assert.Equal(t, original.CreatedBy, copied.CreatedBy)
|
||||
|
||||
// Verify deep copy of AutoGroups (modifying copy doesn't affect original)
|
||||
copied.AutoGroups[0] = "modified"
|
||||
assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0])
|
||||
assert.Equal(t, "group1", original.AutoGroups[0])
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) {
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
assert.NotNil(t, copied.AutoGroups)
|
||||
assert.Len(t, copied.AutoGroups, 0)
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) {
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
AutoGroups: nil,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
assert.NotNil(t, copied.AutoGroups)
|
||||
assert.Len(t, copied.AutoGroups, 0)
|
||||
}
|
||||
|
||||
func TestInviteTokenConstants(t *testing.T) {
|
||||
// Verify constants are consistent
|
||||
expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength
|
||||
assert.Equal(t, InviteTokenLength, expectedLength)
|
||||
assert.Equal(t, 4, len(InviteTokenPrefix))
|
||||
assert.Equal(t, 30, InviteTokenSecretLength)
|
||||
assert.Equal(t, 6, InviteTokenChecksumLength)
|
||||
assert.Equal(t, 40, InviteTokenLength)
|
||||
assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours
|
||||
assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) {
|
||||
// Generate multiple tokens and ensure they all validate
|
||||
for i := 0; i < 50; i++ {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateInviteToken(plainToken)
|
||||
assert.NoError(t, err, "Generated token should always be valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// HashInviteToken should produce the same hash as GenerateInviteToken
|
||||
rehashedToken := HashInviteToken(plainToken)
|
||||
assert.Equal(t, hashedToken, rehashedToken)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -1454,368 +1453,3 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUserInvite creates an invite link for a new user in the embedded IdP.
|
||||
// The user is NOT created until the invite is accepted.
|
||||
func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if err := validateUserInvite(invite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Check if user already exists in NetBird DB
|
||||
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, user := range existingUsers {
|
||||
if strings.EqualFold(user.Email, invite.Email) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invite already exists for this email
|
||||
existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return nil, fmt.Errorf("failed to check existing invites: %w", err)
|
||||
}
|
||||
}
|
||||
if existingInvite != nil {
|
||||
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
|
||||
}
|
||||
|
||||
// Calculate expiration time
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = types.DefaultInviteExpirationSeconds
|
||||
}
|
||||
|
||||
if expiresIn < types.MinInviteExpirationSeconds {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Generate invite token
|
||||
inviteID := types.NewInviteID()
|
||||
hashedToken, plainToken, err := types.GenerateInviteToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate invite token: %w", err)
|
||||
}
|
||||
|
||||
// Create the invite record (no user created yet)
|
||||
userInvite := &types.UserInviteRecord{
|
||||
ID: inviteID,
|
||||
AccountID: accountID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
HashedToken: hashedToken,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
CreatedBy: initiatorUserID,
|
||||
}
|
||||
|
||||
if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email})
|
||||
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
InviteToken: plainToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserInviteInfo retrieves invite information from a token (public endpoint).
|
||||
func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
if err := types.ValidateInviteToken(token); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
|
||||
}
|
||||
|
||||
hashedToken := types.HashInviteToken(token)
|
||||
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the inviter's name
|
||||
invitedBy := ""
|
||||
if invite.CreatedBy != "" {
|
||||
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy)
|
||||
if err == nil && inviter != nil {
|
||||
invitedBy = inviter.Name
|
||||
}
|
||||
}
|
||||
|
||||
return &types.UserInviteInfo{
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
ExpiresAt: invite.ExpiresAt,
|
||||
Valid: !invite.IsExpired(),
|
||||
InvitedBy: invitedBy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListUserInvites returns all invites for an account.
|
||||
func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invites := make([]*types.UserInvite, 0, len(records))
|
||||
for _, record := range records {
|
||||
invites = append(invites, &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: record.ID,
|
||||
Email: record.Email,
|
||||
Name: record.Name,
|
||||
Role: record.Role,
|
||||
AutoGroups: record.AutoGroups,
|
||||
},
|
||||
InviteExpiresAt: record.ExpiresAt,
|
||||
InviteCreatedAt: record.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB.
|
||||
func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if password == "" {
|
||||
return status.Errorf(status.InvalidArgument, "password is required")
|
||||
}
|
||||
|
||||
if err := validatePassword(password); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid password: %v", err)
|
||||
}
|
||||
|
||||
if err := types.ValidateInviteToken(token); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
|
||||
}
|
||||
|
||||
hashedToken := types.HashInviteToken(token)
|
||||
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if invite.IsExpired() {
|
||||
return status.Errorf(status.InvalidArgument, "invite has expired")
|
||||
}
|
||||
|
||||
// Create user in Dex with the provided password
|
||||
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return status.Errorf(status.Internal, "failed to get embedded IdP manager")
|
||||
}
|
||||
|
||||
idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user in IdP: %w", err)
|
||||
}
|
||||
|
||||
// Create user in NetBird DB
|
||||
newUser := &types.User{
|
||||
Id: idpUser.ID,
|
||||
AccountID: invite.AccountID,
|
||||
Role: types.StrRoleToUserRole(invite.Role),
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Issued: types.UserIssuedAPI,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.SaveUser(ctx, newUser); err != nil {
|
||||
return fmt.Errorf("failed to save user: %w", err)
|
||||
}
|
||||
if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete invite: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Best-effort rollback: delete the IdP user to avoid orphaned records
|
||||
if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil {
|
||||
log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one.
|
||||
func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Get existing invite
|
||||
existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate expiration time
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = types.DefaultInviteExpirationSeconds
|
||||
}
|
||||
if expiresIn < types.MinInviteExpirationSeconds {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Generate new invite token
|
||||
hashedToken, plainToken, err := types.GenerateInviteToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate invite token: %w", err)
|
||||
}
|
||||
|
||||
// Update existing invite with new token and expiration
|
||||
existingInvite.HashedToken = hashedToken
|
||||
existingInvite.ExpiresAt = expiresAt
|
||||
existingInvite.CreatedBy = initiatorUserID
|
||||
|
||||
err = am.Store.SaveUserInvite(ctx, existingInvite)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email})
|
||||
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: existingInvite.ID,
|
||||
Email: existingInvite.Email,
|
||||
Name: existingInvite.Name,
|
||||
Role: existingInvite.Role,
|
||||
AutoGroups: existingInvite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
InviteToken: plainToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteUserInvite deletes an existing invite by ID.
|
||||
func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
var hasDigit, hasUpper, hasSpecial bool
|
||||
for _, c := range password {
|
||||
switch {
|
||||
case unicode.IsDigit(c):
|
||||
hasDigit = true
|
||||
case unicode.IsUpper(c):
|
||||
hasUpper = true
|
||||
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
var missing []string
|
||||
if !hasDigit {
|
||||
missing = append(missing, "one digit")
|
||||
}
|
||||
if !hasUpper {
|
||||
missing = append(missing, "one uppercase letter")
|
||||
}
|
||||
if !hasSpecial {
|
||||
missing = append(missing, "one special character")
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return errors.New("password must contain at least " + strings.Join(missing, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -72,8 +72,8 @@ var (
|
||||
|
||||
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
if err != nil && !strings.Contains(keysLocation, "localhost") {
|
||||
log.WithField("keysLocation", keysLocation).Warnf("could not get keys from location: %s, it will try again on the next http request", err)
|
||||
if err != nil {
|
||||
log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
|
||||
@@ -488,171 +488,6 @@ components:
|
||||
- role
|
||||
- auto_groups
|
||||
- is_service_user
|
||||
UserInviteCreateRequest:
|
||||
type: object
|
||||
description: Request to create a user invite link
|
||||
properties:
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
example: user
|
||||
auto_groups:
|
||||
description: Group IDs to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
expires_in:
|
||||
description: Invite expiration time in seconds (default 72 hours)
|
||||
type: integer
|
||||
example: 259200
|
||||
required:
|
||||
- email
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
UserInvite:
|
||||
type: object
|
||||
description: A user invite
|
||||
properties:
|
||||
id:
|
||||
description: Invite ID
|
||||
type: string
|
||||
example: d5p7eedra0h0lt6f59hg
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
example: user
|
||||
auto_groups:
|
||||
description: Group IDs to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
expires_at:
|
||||
description: Invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-25T10:00:00Z"
|
||||
created_at:
|
||||
description: Invite creation time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-22T10:00:00Z"
|
||||
expired:
|
||||
description: Whether the invite has expired
|
||||
type: boolean
|
||||
example: false
|
||||
invite_token:
|
||||
description: The invite link to be shared with the user. Only returned when the invite is created or regenerated.
|
||||
type: string
|
||||
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
- expires_at
|
||||
- created_at
|
||||
- expired
|
||||
UserInviteInfo:
|
||||
type: object
|
||||
description: Public information about an invite
|
||||
properties:
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
expires_at:
|
||||
description: Invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-25T10:00:00Z"
|
||||
valid:
|
||||
description: Whether the invite is still valid (not expired)
|
||||
type: boolean
|
||||
example: true
|
||||
invited_by:
|
||||
description: Name of the user who sent the invite
|
||||
type: string
|
||||
example: Admin User
|
||||
required:
|
||||
- email
|
||||
- name
|
||||
- expires_at
|
||||
- valid
|
||||
- invited_by
|
||||
UserInviteAcceptRequest:
|
||||
type: object
|
||||
description: Request to accept an invite and set password
|
||||
properties:
|
||||
password:
|
||||
description: >-
|
||||
The password the user wants to set. Must be at least 8 characters long
|
||||
and contain at least one uppercase letter, one digit, and one special
|
||||
character (any character that is not a letter or digit, including spaces).
|
||||
type: string
|
||||
format: password
|
||||
minLength: 8
|
||||
pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$'
|
||||
example: SecurePass123!
|
||||
required:
|
||||
- password
|
||||
UserInviteAcceptResponse:
|
||||
type: object
|
||||
description: Response after accepting an invite
|
||||
properties:
|
||||
success:
|
||||
description: Whether the invite was accepted successfully
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- success
|
||||
UserInviteRegenerateRequest:
|
||||
type: object
|
||||
description: Request to regenerate an invite link
|
||||
properties:
|
||||
expires_in:
|
||||
description: Invite expiration time in seconds (default 72 hours)
|
||||
type: integer
|
||||
example: 259200
|
||||
UserInviteRegenerateResponse:
|
||||
type: object
|
||||
description: Response after regenerating an invite
|
||||
properties:
|
||||
invite_token:
|
||||
description: The new invite token
|
||||
type: string
|
||||
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
|
||||
invite_expires_at:
|
||||
description: New invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-28T10:00:00Z"
|
||||
required:
|
||||
- invite_token
|
||||
- invite_expires_at
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -2236,8 +2071,7 @@ components:
|
||||
"dns.zone.create", "dns.zone.update", "dns.zone.delete",
|
||||
"dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete",
|
||||
"peer.job.create",
|
||||
"user.password.change",
|
||||
"user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete"
|
||||
"user.password.change"
|
||||
]
|
||||
example: route.add
|
||||
initiator_id:
|
||||
@@ -2808,29 +2642,6 @@ components:
|
||||
required:
|
||||
- user_id
|
||||
- email
|
||||
InstanceVersionInfo:
|
||||
type: object
|
||||
description: Version information for NetBird components
|
||||
properties:
|
||||
management_current_version:
|
||||
description: The current running version of the management server
|
||||
type: string
|
||||
example: "0.35.0"
|
||||
dashboard_available_version:
|
||||
description: The latest available version of the dashboard (from GitHub releases)
|
||||
type: string
|
||||
example: "2.10.0"
|
||||
management_available_version:
|
||||
description: The latest available version of the management server (from GitHub releases)
|
||||
type: string
|
||||
example: "0.35.0"
|
||||
management_update_available:
|
||||
description: Indicates if a newer management version is available
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- management_current_version
|
||||
- management_update_available
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
@@ -2883,27 +2694,6 @@ paths:
|
||||
$ref: '#/components/schemas/InstanceStatus'
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/instance/version:
|
||||
get:
|
||||
summary: Get Version Info
|
||||
description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub.
|
||||
tags: [ Instance ]
|
||||
security:
|
||||
- BearerAuth: []
|
||||
- TokenAuth: []
|
||||
responses:
|
||||
'200':
|
||||
description: Version information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/InstanceVersionInfo'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/setup:
|
||||
post:
|
||||
summary: Setup Instance
|
||||
@@ -3522,210 +3312,6 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites:
|
||||
get:
|
||||
summary: List user invites
|
||||
description: Lists all pending invites for the account. Only available when embedded IdP is enabled.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: List of invites
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/UserInvite'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a user invite
|
||||
description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
description: User invite information
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteCreateRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite created successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInvite'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'409':
|
||||
description: User or invite already exists
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{inviteId}:
|
||||
delete:
|
||||
summary: Delete a user invite
|
||||
description: Deletes a pending invite. Only available when embedded IdP is enabled.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: inviteId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the invite to delete
|
||||
responses:
|
||||
'200':
|
||||
description: Invite deleted successfully
|
||||
content: { }
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
description: Invite not found
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{inviteId}/regenerate:
|
||||
post:
|
||||
summary: Regenerate a user invite
|
||||
description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: inviteId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the invite to regenerate
|
||||
requestBody:
|
||||
description: Regenerate options
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteRegenerateRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite regenerated successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteRegenerateResponse'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
description: Invite not found
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{token}:
|
||||
get:
|
||||
summary: Get invite information
|
||||
description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself.
|
||||
tags: [ Users ]
|
||||
security: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The invite token
|
||||
responses:
|
||||
'200':
|
||||
description: Invite information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteInfo'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'404':
|
||||
description: Invite not found or invalid token
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{token}/accept:
|
||||
post:
|
||||
summary: Accept an invite
|
||||
description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself.
|
||||
tags: [ Users ]
|
||||
security: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The invite token
|
||||
requestBody:
|
||||
description: Password to set for the new user
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteAcceptRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite accepted successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteAcceptResponse'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'404':
|
||||
description: Invite not found or invalid token
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled or invite expired
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers:
|
||||
get:
|
||||
summary: List all Peers
|
||||
|
||||
@@ -123,10 +123,6 @@ const (
|
||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||
EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept"
|
||||
EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create"
|
||||
EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete"
|
||||
EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate"
|
||||
EventActivityCodeUserJoin EventActivityCode = "user.join"
|
||||
EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change"
|
||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||
@@ -874,21 +870,6 @@ type InstanceStatus struct {
|
||||
SetupRequired bool `json:"setup_required"`
|
||||
}
|
||||
|
||||
// InstanceVersionInfo Version information for NetBird components
|
||||
type InstanceVersionInfo struct {
|
||||
// DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases)
|
||||
DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"`
|
||||
|
||||
// ManagementAvailableVersion The latest available version of the management server (from GitHub releases)
|
||||
ManagementAvailableVersion *string `json:"management_available_version,omitempty"`
|
||||
|
||||
// ManagementCurrentVersion The current running version of the management server
|
||||
ManagementCurrentVersion string `json:"management_current_version"`
|
||||
|
||||
// ManagementUpdateAvailable Indicates if a newer management version is available
|
||||
ManagementUpdateAvailable bool `json:"management_update_available"`
|
||||
}
|
||||
|
||||
// JobRequest defines model for JobRequest.
|
||||
type JobRequest struct {
|
||||
Workload WorkloadRequest `json:"workload"`
|
||||
@@ -2185,99 +2166,6 @@ type UserCreateRequest struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInvite A user invite
|
||||
type UserInvite struct {
|
||||
// AutoGroups Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// CreatedAt Invite creation time
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// Expired Whether the invite has expired
|
||||
Expired bool `json:"expired"`
|
||||
|
||||
// ExpiresAt Invite expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
// Id Invite ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated.
|
||||
InviteToken *string `json:"invite_token,omitempty"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInviteAcceptRequest Request to accept an invite and set password
|
||||
type UserInviteAcceptRequest struct {
|
||||
// Password The password the user wants to set. Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces).
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// UserInviteAcceptResponse Response after accepting an invite
|
||||
type UserInviteAcceptResponse struct {
|
||||
// Success Whether the invite was accepted successfully
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// UserInviteCreateRequest Request to create a user invite link
|
||||
type UserInviteCreateRequest struct {
|
||||
// AutoGroups Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// ExpiresIn Invite expiration time in seconds (default 72 hours)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInviteInfo Public information about an invite
|
||||
type UserInviteInfo struct {
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// ExpiresAt Invite expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
// InvitedBy Name of the user who sent the invite
|
||||
InvitedBy string `json:"invited_by"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Valid Whether the invite is still valid (not expired)
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// UserInviteRegenerateRequest Request to regenerate an invite link
|
||||
type UserInviteRegenerateRequest struct {
|
||||
// ExpiresIn Invite expiration time in seconds (default 72 hours)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
}
|
||||
|
||||
// UserInviteRegenerateResponse Response after regenerating an invite
|
||||
type UserInviteRegenerateResponse struct {
|
||||
// InviteExpiresAt New invite expiration time
|
||||
InviteExpiresAt time.Time `json:"invite_expires_at"`
|
||||
|
||||
// InviteToken The new invite token
|
||||
InviteToken string `json:"invite_token"`
|
||||
}
|
||||
|
||||
// UserPermissions defines model for UserPermissions.
|
||||
type UserPermissions struct {
|
||||
// IsRestricted Indicates whether this User's Peers view is restricted
|
||||
@@ -2530,15 +2418,6 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
|
||||
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
|
||||
type PostApiUsersJSONRequestBody = UserCreateRequest
|
||||
|
||||
// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType.
|
||||
type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest
|
||||
|
||||
// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType.
|
||||
type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest
|
||||
|
||||
// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType.
|
||||
type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest
|
||||
|
||||
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
|
||||
type PutApiUsersUserIdJSONRequestBody = UserRequest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user