Compare commits

..

2 Commits

Author SHA1 Message Date
crn4
97ad3307dd Merge branch 'main' into log/conn-disconn 2026-01-23 19:59:25 +01:00
crn4
f6cc27d675 add some logs for conn/disconn status change 2026-01-23 18:01:26 +01:00
81 changed files with 1102 additions and 7072 deletions

View File

@@ -3,7 +3,15 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -108,17 +129,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -137,41 +160,49 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
go urlOpener.OnLoginSuccess()
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
return nil, err
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -181,10 +212,22 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -7,6 +7,7 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -276,19 +277,18 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -300,9 +300,23 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
@@ -330,7 +344,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -177,13 +176,7 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}

View File

@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
// Egress rules: match outgoing loopback UDP packets
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
return fmt.Errorf("add output sport notrack rule: %w", err)
}
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
return fmt.Errorf("add output dport notrack rule: %w", err)
}
// Ingress rules: match incoming loopback UDP packets
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
}
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
}
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("create chain: %w", err)
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add prerouting jump rule: %w", err)
}
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
return nil
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("clear and delete chain: %w", err)
}
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -168,10 +168,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,10 +48,8 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
router *router
aclManager *AclManager
}
// Create nftables firewall manager
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil {
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
return fmt.Errorf("notrack chains not initialized")
}
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
loopback := []byte{127, 0, 0, 1}
// Egress rules: match outgoing loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
// Ingress rules: match incoming loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
&expr.Counter{},
&expr.Notrack{},
},
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush notrack rules: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRaw,
})
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawPrerouting,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRaw,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush chain creation: %w", err)
}
return nil
}
func (m *Manager) refreshNoTrackChains() error {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
tableName := getTableName()
for _, c := range chains {
if c.Table.Name != tableName {
continue
}
switch c.Name {
case chainNameRawOutput:
m.notrackOutputChain = c
case chainNameRawPrerouting:
m.notrackPreroutingChain = c
}
}
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -570,14 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -50,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error
}
@@ -81,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()

View File

@@ -114,21 +114,34 @@ func (p *ProxyBind) Pause() {
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = ep
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to convert endpoint address: %v", err)
} else {
p.wgCurrentUsed = ep
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, errors.New("nil address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
@@ -212,16 +225,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}

View File

@@ -27,19 +27,12 @@ const (
)
var (
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
proxyPort int
mtu uint16
ebpfManager ebpfMgr.Manager
@@ -47,8 +40,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
lastUsedPort uint16
rawConnIPv4 net.PacketConn
rawConnIPv6 net.PacketConn
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
@@ -70,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
proxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
// Prepare IPv6 raw socket (optional)
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
return err
}
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err
}
addr := net.UDPAddr{
Port: proxyPort,
Port: wgPorxyPort,
IP: net.ParseIP(loopbackAddr),
}
@@ -118,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", proxyPort)
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil
}
@@ -159,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err)
}
if p.rawConnIPv4 != nil {
if err := p.rawConnIPv4.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -255,60 +218,31 @@ generatePort:
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP
var rawConn net.PacketConn
if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
rawConn = p.rawConnIPv4
} else {
// IPv6 path
if p.rawConnIPv6 == nil {
return fmt.Errorf("IPv6 raw socket not available")
}
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
rawConn = p.rawConnIPv6
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil

View File

@@ -41,7 +41,7 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
@@ -91,14 +91,12 @@ func (p *ProxyWrapper) Pause() {
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
if endpoint != nil && endpoint.IP != nil {
p.wgEndpointCurrentUsedAddr = endpoint
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()

View File

@@ -54,14 +54,6 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
func (w *KernelFactory) GetProxyPort() uint16 {
if w.ebpfProxy == nil {
return 0
}
return w.ebpfProxy.GetProxyPort()
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil

View File

@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
func (w *USPFactory) GetProxyPort() uint16 {
return 0
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -8,87 +8,43 @@ import (
"os"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net"
)
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET, true)
}
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET6, false)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
if isIPv4 {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil
}

View File

@@ -1,353 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
func TestRedirectAs_UDP_IPv6(t *testing.T) {
wgPort := 51853
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
t.Helper()
ctx := context.Background()
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
}
defer wgListener4.Close()
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
}
defer wgListener6.Close()
// Determine which listener to use based on the NetBird address IP version
// (this is where initial traffic will come from before RedirectAs is called)
var wgListener *net.UDPConn
if p2pEndpoint.IP.To4() == nil {
wgListener = wgListener6
} else {
wgListener = wgListener4
}
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // Random port
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
// Add TURN connection to proxy
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
// Start the proxy
proxy.Work()
// Phase 1: Test initial relay traffic
msgFromRelay := []byte("hello from relay")
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write to relay server: %v", err)
}
// Set read deadline to avoid hanging
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
buf := make([]byte, 1024)
n, _, err := wgListener4.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read from WireGuard listener: %v", err)
}
if n != len(msgFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
}
if string(buf[:n]) != string(msgFromRelay) {
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
}
// Phase 2: Redirect to p2p endpoint
proxy.RedirectAs(p2pEndpoint)
// Give the proxy a moment to process the redirect
time.Sleep(100 * time.Millisecond)
// Phase 3: Test redirected traffic
redirectedMessages := [][]byte{
[]byte("redirected message 1"),
[]byte("redirected message 2"),
[]byte("redirected message 3"),
}
for i, msg := range redirectedMessages {
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
}
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
}
// Verify message content
if string(buf[:n]) != string(msg) {
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
}
// Verify source address matches p2p endpoint (this is the key test)
// Use compareUDPAddr to ignore IPv6 zone IDs
if !compareUDPAddr(srcAddr, p2pEndpoint) {
t.Errorf("message %d: expected source address %s, got %s",
i+1, p2pEndpoint.String(), srcAddr.String())
}
}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
ctx := context.Background()
// Create WireGuard listener
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create WireGuard listener: %v", err)
}
defer wgListener.Close()
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
proxy.Work()
// Test switching between multiple endpoints - using addresses in local subnet
endpoints := []*net.UDPAddr{
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
}
for i, endpoint := range endpoints {
proxy.RedirectAs(endpoint)
time.Sleep(100 * time.Millisecond)
msg := []byte("test message")
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
}
buf := make([]byte, 1024)
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
}
if !compareUDPAddr(srcAddr, endpoint) {
t.Errorf("endpoint %d: expected source %s, got %s",
i, endpoint.String(), srcAddr.String())
}
}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {

View File

@@ -19,56 +19,37 @@ var (
FixLengths: true,
}
localHostNetIPAddrV4 = &net.IPAddr{
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
// Create only the raw socket for the address family we need
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(networkLayer)
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}

View File

@@ -1,499 +0,0 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
mutex sync.RWMutex
client *mgm.GrpcClient
config *profilemanager.Config
privateKey wgtypes.Key
mgmURL *url.URL
mgmTLSEnabled bool
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The connection is automatically recreated with backoff if it becomes disconnected
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
privateKey: myPrivateKey,
mgmURL: mgmURL,
mgmTLSEnabled: mgmTLSEnabled,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
a.mutex.Lock()
defer a.mutex.Unlock()
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
var supportsSSO bool
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
// Try PKCE flow first
_, err := a.getPKCEFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if PKCE is not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if Device flow is also not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// Neither PKCE nor Device flow is supported
supportsSSO = false
return nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return err
})
return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
// This avoids creating a new connection to the management server
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
var flow OAuthFlow
var err error
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
if forceDeviceAuth {
flow, err = a.getDeviceFlow(client)
return err
}
// Try PKCE flow first
flow, err = a.getPKCEFlow(client)
if err != nil {
// If PKCE not supported, try Device flow
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
flow, err = a.getDeviceFlow(client)
return err
}
return err
}
return nil
})
return flow, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
}
needsLogin = false
return err
})
return needsLogin, err
}
// Login attempts to log in or register the client with the management server
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return err, false
}
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
} else if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
isAuthError = false
return nil
})
return err, isAuthError
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
// setSystemInfoFlags sets all configuration flags on the provided system info
func (a *Auth) setSystemInfoFlags(info *system.Info) {
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
}
// reconnect closes the current connection and creates a new one
// It checks if the brokenClient is still the current client before reconnecting
// to avoid multiple threads reconnecting unnecessarily
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
a.mutex.Lock()
defer a.mutex.Unlock()
// Double-check: if client has already been replaced by another thread, skip reconnection
if a.client != brokenClient {
log.Debugf("client already reconnected by another thread, skipping")
return nil
}
// Create new connection FIRST, before closing the old one
// This ensures a.client is never nil, preventing panics in other threads
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
// Keep the old client if reconnection fails
return err
}
// Close old connection AFTER new one is successfully created
oldClient := a.client
a.client = mgmClient
if oldClient != nil {
if err := oldClient.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
return nil
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
// Capture the client BEFORE the operation to ensure we track the correct client
a.mutex.RLock()
currentClient := a.client
a.mutex.RUnlock()
if currentClient == nil {
return status.Errorf(codes.Unavailable, "client is not initialized")
}
// Execute operation with the captured client
err := operation(currentClient)
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection using the client that was actually used
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors, don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// isPermissionDenied checks if the error is a PermissionDenied error.
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
func isPermissionDenied(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.PermissionDenied
}
func isLoginNeeded(err error) bool {
return isAuthenticationError(err)
}
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -25,56 +26,12 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -247,22 +199,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,6 +12,8 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,6 +10,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -86,33 +87,19 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
return pkceFlowInfo, nil
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -127,9 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
return deviceFlowInfo, nil
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -20,6 +20,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -34,67 +35,17 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig PKCEAuthProviderConfig
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -173,21 +124,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -198,7 +138,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -209,8 +149,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
go p.startServer(server, tokenChan, errChan)
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -49,7 +50,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,6 +9,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -93,7 +95,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -0,0 +1,136 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -505,10 +505,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set up notrack rules immediately after proxy is listening to prevent
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -621,23 +617,6 @@ func (e *Engine) initFirewall() error {
return nil
}
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
func (e *Engine) setupWGProxyNoTrack() {
if e.firewall == nil {
return
}
proxyPort := e.wgInterface.GetProxyPort()
if proxyPort == 0 {
return
}
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
}
}
func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
@@ -1665,7 +1644,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)

View File

@@ -107,7 +107,6 @@ type MockWGIface struct {
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetProxyPortFunc func() uint16
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
@@ -204,13 +203,6 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
return m.GetProxyFunc()
}
func (m *MockWGIface) GetProxyPort() uint16 {
if m.GetProxyPortFunc != nil {
return m.GetProxyPortFunc()
}
return 0
}
func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}

View File

@@ -28,7 +28,6 @@ type wgIfaceBase interface {
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error

201
client/internal/login.go Normal file
View File

@@ -0,0 +1,201 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -0,0 +1,138 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -263,14 +263,7 @@ func (c *Client) IsLoginRequired() bool {
return true
}
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx)
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
@@ -321,19 +314,16 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
return
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}

View File

@@ -7,8 +7,13 @@ import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -85,21 +90,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
@@ -123,17 +141,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
@@ -144,16 +164,15 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
// LoginSync performs a synchronous login check without UI interaction
// Used for background VPN connection where user should already be authenticated
func (a *Auth) LoginSync() error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -161,12 +180,15 @@ func (a *Auth) LoginSync() error {
return fmt.Errorf("not authenticated")
}
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
if err != nil {
if isAuthError {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
@@ -203,6 +225,8 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
@@ -210,33 +234,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(ctx)
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err, isAuthError := authClient.Login(ctx, "", jwtToken)
if err != nil {
if isAuthError {
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
@@ -261,10 +285,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
return nil, err
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
@@ -289,6 +313,15 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).

View File

@@ -253,17 +253,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
err := internal.Login(ctx, s.config, setupKey, jwtToken)
if err != nil {
if isAuthError {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Warnf("failed login: %v", err)
status = internal.StatusNeedsLogin
} else {
@@ -588,7 +581,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitCTX, cancel := context.WithCancel(ctx)
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
s.mutex.Lock()

View File

@@ -207,6 +207,8 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
@@ -214,28 +216,10 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
<-session.Context().Done()
if err := serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
done := make(chan error, 1)
go func() {
done <- serverSession.Wait()
}()
select {
case <-session.Context().Done():
return
case err := <-done:
if err != nil {
log.Debugf("shell session: %v", err)
p.handleProxyExitCode(session, err)
}
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}

View File

@@ -12,8 +12,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handleExecution executes an SSH command or shell with privilege validation
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
@@ -23,7 +23,7 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
@@ -51,12 +51,13 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
@@ -65,28 +66,28 @@ func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheck
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
logger.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil
}

View File

@@ -15,17 +15,17 @@ import (
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
return nil
}

View File

@@ -10,7 +10,6 @@ import (
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"sync"
@@ -100,52 +99,40 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
return isUtilLinux
}
// createSuCommand creates a command using su - for privilege switching.
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
if err := validateUsername(localUser.Username); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
args := []string{"-"}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username)
args = append(args, localUser.Username, "-c", command)
command := session.RawCommand()
if command != "" {
args = append(args, "-c", command)
}
logger.Debugf("creating su command: %s %v", suPath, args)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string.
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell}
return []string{shell, "-l"}
}
return []string{shell, "-c", cmdString}
}
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
cmd := exec.CommandContext(ctx, shell, args[1:]...)
cmd.Args[0] = "-" + filepath.Base(shell)
return cmd
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
@@ -167,7 +154,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
logger.Errorf("Pty command creation failed: %v", err)
@@ -257,6 +244,11 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
}()
go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err)
@@ -276,7 +268,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done:
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
s.handlePtyCommandCompletion(logger, session, err)
}
}
@@ -304,20 +296,17 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
}
}
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
if err != nil {
logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
} else {
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
return
}
// Close PTY to unblock io.Copy goroutines
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close after completion: %v", err)
// Normal completion
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}

View File

@@ -20,32 +20,32 @@ import (
// getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
userToken, err := s.getUserToken(logger, username, domain)
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain)
if err != nil {
return nil, fmt.Errorf("get user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
logger.Debugf("close user token: %v", err)
log.Debugf("close user token: %v", err)
}
}()
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
}
envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil {
logger.Debugf("failed to load system environment from registry: %v", err)
log.Debugf("failed to load system environment from registry: %v", err)
}
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken window
}
// getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper()
token, err := privilegeDropper.createToken(username, domain)
if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err)
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
}
// prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(logger, username, domain)
userEnv, err := s.getUserEnvironment(username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
@@ -267,16 +267,22 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
return env
}
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true
}
@@ -288,6 +294,11 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
return []string{shell, "-Command", cmdString}
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct {
Shell string
Command string
@@ -297,25 +308,25 @@ type PtyExecutionRequest struct {
Domain string
}
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
privilegeDropper := NewPrivilegeDropper()
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil {
return fmt.Errorf("create user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
logger.Debugf("close user token: %v", err)
log.Debugf("close user token: %v", err)
}
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
if err != nil {
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
}
@@ -337,8 +348,8 @@ func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req
Environment: userEnv,
}
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
}
func getUserHomeFromEnv(env []string) string {
@@ -360,8 +371,10 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil {
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
logger.Debugf("kill process failed: %v", err)
}
}
@@ -376,7 +389,21 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
@@ -388,14 +415,14 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
req := PtyExecutionRequest{
Shell: shell,
Command: session.RawCommand(),
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)

View File

@@ -4,15 +4,12 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
@@ -26,67 +23,25 @@ import (
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
// binary with "ssh exec" args. We handle that here to properly execute commands and
// propagate exit codes.
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
runTestExecutor()
return
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
// runTestExecutor emulates the netbird executor for tests.
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
func runTestExecutor() {
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
os.Exit(1)
}
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
shell := "/bin/sh"
var command string
for i := 3; i < len(os.Args); i++ {
switch os.Args[i] {
case "--shell":
if i+1 < len(os.Args) {
shell = os.Args[i+1]
i++
}
case "--cmd":
if i+1 < len(os.Args) {
command = os.Args[i+1]
i++
}
}
}
var cmd *exec.Cmd
if command == "" {
cmd = exec.Command(shell)
} else {
cmd = exec.Command(shell, "-c", command)
}
cmd.Args[0] = "-" + filepath.Base(shell)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
os.Exit(0)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() {
@@ -450,171 +405,6 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey)
}
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
// This ensures our implementation matches OpenSSH behavior for:
// - ssh host command (no PTY - default when no TTY)
// - ssh -T host command (explicit no PTY)
// - ssh -t host command (force PTY)
// - ssh -T host (no PTY shell - our implementation)
func TestSSHPtyModes(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH PTY mode tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
}
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
username := testutil.GetTestUsername(t)
baseArgs := []string{
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "BatchMode=yes",
}
t.Run("command_default_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
assert.Contains(t, string(output), "no_pty_default")
})
t.Run("command_explicit_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
assert.Contains(t, string(output), "explicit_no_pty")
})
t.Run("command_force_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
assert.Contains(t, string(output), "force_pty")
})
t.Run("shell_explicit_no_pty", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
cmd := exec.CommandContext(ctx, "ssh", args...)
stdin, err := cmd.StdinPipe()
require.NoError(t, err)
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
go func() {
defer stdin.Close()
time.Sleep(100 * time.Millisecond)
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
assert.NoError(t, err, "write echo command")
time.Sleep(100 * time.Millisecond)
_, err = stdin.Write([]byte("exit 0\n"))
assert.NoError(t, err, "write exit command")
}()
output, _ := io.ReadAll(stdout)
err = cmd.Wait()
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
assert.Contains(t, string(output), "shell_no_pty_test")
})
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "Command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
})
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "PTY command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
})
t.Run("stderr_works_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "stderr test failed")
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
})
t.Run("stderr_merged_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "PTY stderr test failed: %s", output)
assert.Contains(t, string(output), "stdout_msg")
assert.Contains(t, string(output), "stderr_msg")
})
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() {

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
@@ -36,35 +35,11 @@ type ExecutorConfig struct {
}
// PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct {
logger *log.Entry
}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
type PrivilegeDropper struct{}
// NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
@@ -108,7 +83,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
break
}
}
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil
}
@@ -231,22 +206,17 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
var execCmd *exec.Cmd
if config.Command == "" {
execCmd = exec.CommandContext(ctx, config.Shell)
} else {
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
os.Exit(ExitCodeSuccess)
}
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
if config.Command == "" {
log.Tracef("executing login shell: %s", execCmd.Path)
} else {
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
}
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {

View File

@@ -28,45 +28,22 @@ const (
)
type WindowsExecutorConfig struct {
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Pty bool
PtyWidth int
PtyHeight int
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Interactive bool
Pty bool
PtyWidth int
PtyHeight int
}
type PrivilegeDropper struct {
logger *log.Entry
}
type PrivilegeDropper struct{}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
var (
@@ -79,6 +56,7 @@ const (
// Common error messages
commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w"
)
@@ -102,7 +80,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
shellArgs = []string{shell}
}
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
@@ -202,10 +180,10 @@ func newLsaString(s string) lsaString {
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper(WithLogger(logger))
pd := NewPrivilegeDropper()
isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection()
@@ -219,12 +197,12 @@ func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.H
return 0, err
}
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
if err != nil {
return 0, err
}
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
}
// buildUserCpn constructs the user principal name
@@ -332,21 +310,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
}
// prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser {
return prepareDomainS4ULogon(logger, username, domain)
return prepareDomainS4ULogon(username, domain)
}
return prepareLocalS4ULogon(logger, username)
return prepareLocalS4ULogon(username)
}
// prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
}
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn)
if err != nil {
@@ -379,8 +357,8 @@ func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.P
}
// prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username)
if err != nil {
@@ -428,11 +406,11 @@ func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, u
}
// performS4ULogon executes the S4U logon operation
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
logger.Debugf("AllocateLocallyUniqueId failed")
log.Debugf("AllocateLocallyUniqueId failed")
}
originName := newLsaString("netbird")
@@ -463,7 +441,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
}
}
@@ -471,7 +449,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
}
logger.Debugf("created S4U %s token for user %s",
log.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil
}
@@ -519,8 +497,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
// authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, ".")
log.Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".")
if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
}
@@ -529,12 +507,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
// authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, domain)
log.Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain)
if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
}
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
return token, nil
}
@@ -548,7 +526,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
defer func() {
if err := windows.CloseHandle(token); err != nil {
pd.log().Debugf("close impersonation token: %v", err)
log.Debugf("close impersonation token: %v", err)
}
}()
@@ -586,7 +564,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
return cmd, primaryToken, nil
}
// createSuCommand creates a command using su - for privilege switching (Windows stub).
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows")
}

View File

@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)

View File

@@ -271,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg

View File

@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {

View File

@@ -483,11 +483,12 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
}
}
func TestServer_NonPtyShellSession(t *testing.T) {
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
@@ -495,26 +496,36 @@ func TestServer_NonPtyShellSession(t *testing.T) {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "shell_with_local_forwarding_enabled",
name: "session_allowed_with_local_forwarding",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "shell_with_remote_forwarding_enabled",
name: "session_allowed_with_remote_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "shell_with_both_forwarding_enabled",
name: "session_allowed_with_both",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "shell_with_forwarding_disabled",
name: "session_denied_without_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
@@ -534,6 +545,7 @@ func TestServer_NonPtyShellSession(t *testing.T) {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -545,10 +557,20 @@ func TestServer_NonPtyShellSession(t *testing.T) {
_ = client.Close()
}()
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
// Should always succeed regardless of port forwarding settings
_, err = client.ExecuteCommand(ctx, "")
assert.NoError(t, err, "Non-PTY shell session should be allowed")
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
})
}
}

View File

@@ -405,14 +405,12 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
} else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-c", args[1])
assert.Equal(t, "echo test", args[2])
args = server.getShellCommandArgs("/bin/sh", "")
assert.Equal(t, "/bin/sh", args[0])
assert.Len(t, args, 1)
assert.Equal(t, "-l", args[1])
assert.Equal(t, "-c", args[2])
assert.Equal(t, "echo test", args[3])
}
}

View File

@@ -62,12 +62,54 @@ func (s *Server) sessionHandler(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
} else {
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
switch {
case isPty && hasCommand:
// ssh -t <host> <cmd> - Pty command execution
s.handleCommand(logger, session, privilegeResult, winCh)
case isPty:
// ssh <host> - Pty interactive session (login)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
}
}

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handlePtyLogin is not supported on JS/WASM
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
// handlePty is not supported on JS/WASM
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
logger.Debugf(errWriteSession, err)

View File

@@ -8,18 +8,19 @@ import (
"time"
)
// StartTestServer starts the SSH server and returns the address it's listening on.
func StartTestServer(t *testing.T, server *Server) string {
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
// Use port 0 to let the OS assign a free port
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
// Get the actual listening address from the server
actualAddr := server.Addr()
if actualAddr == nil {
errChan <- fmt.Errorf("server started but no listener address available")

View File

@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
// Returns the command and a cleanup function (no-op on Unix).
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
}
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
privilegeDropper := NewPrivilegeDropper()
config := ExecutorConfig{
UID: uid,
GID: gid,
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := s.createShellCommand(session.Context(), shell, args)
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
cmd.Dir = localUser.HomeDir
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)

View File

@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
// createExecutorCommand creates a command using Windows executor for privilege dropping.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
}
return s.createUserSwitchCommand(logger, session, localUser)
return s.createUserSwitchCommand(localUser, session, hasPty)
}
// createUserSwitchCommand creates a command with Windows user switching.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
@@ -113,14 +113,15 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
}
config := WindowsExecutorConfig{
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
Interactive: interactive || (rawCmd == ""),
}
dropper := NewPrivilegeDropper(WithLogger(logger))
dropper := NewPrivilegeDropper()
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil {
return nil, nil, err
@@ -129,7 +130,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
cleanup := func() {
if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
logger.Debugf("close primary token: %v", err)
log.Debugf("close primary token: %v", err)
}
}
}

View File

@@ -56,7 +56,7 @@ var (
)
// ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args)
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfi
Pty: ptyConfig,
User: userConfig,
Session: session,
Context: session.Context(),
Context: ctx,
}
return executeConPtyWithConfig(commandLine, config)

View File

@@ -63,8 +63,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleNetworksClick()
case <-h.client.mNotifications.ClickedCh:
h.handleNotificationsClick()
case <-systray.TrayOpenedCh:
h.client.updateExitNodes()
}
}
}

View File

@@ -341,6 +341,7 @@ func (s *serviceClient) updateExitNodes() {
log.Errorf("get client: %v", err)
return
}
exitNodes, err := s.getExitNodes(conn)
if err != nil {
log.Errorf("get exit nodes: %v", err)

2
go.mod
View File

@@ -31,7 +31,7 @@ require (
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3

4
go.sum
View File

@@ -13,8 +13,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI=
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=

View File

@@ -232,9 +232,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
if err != nil {
s.syncSem.Add(-1)
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return mapError(ctx, err)
}

View File

@@ -30,12 +30,6 @@ type Manager interface {
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInvite(ctx context.Context, token, password string) error
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error

View File

@@ -199,11 +199,6 @@ const (
UserPasswordChanged Activity = 103
UserInviteLinkCreated Activity = 104
UserInviteLinkAccepted Activity = 105
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
AccountDeleted Activity = 99999
)
@@ -332,11 +327,6 @@ var activityMap = map[Activity]Code{
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
UserPasswordChanged: {"User password changed", "user.password.change"},
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -68,13 +68,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// Public invite endpoints (tokens start with nbi_)
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -139,8 +132,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
@@ -154,7 +145,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {

View File

@@ -28,15 +28,6 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
@@ -74,29 +65,3 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
Email: userData.Email,
})
}
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
return
}
resp := api.InstanceVersionInfo{
ManagementCurrentVersion: versionInfo.CurrentVersion,
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
}
if versionInfo.DashboardVersion != "" {
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
}
if versionInfo.ManagementVersion != "" {
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
}
util.WriteJSONObject(r.Context(), w, resp)
}

View File

@@ -25,7 +25,6 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
@@ -67,18 +66,6 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
}
return &nbinstance.VersionInfo{
CurrentVersion: "0.34.0",
DashboardVersion: "2.0.0",
ManagementVersion: "0.35.0",
ManagementUpdateAvailable: true,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
@@ -292,44 +279,3 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceVersionInfo
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
assert.NotNil(t, response.DashboardAvailableVersion)
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
assert.NotNil(t, response.ManagementAvailableVersion)
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
assert.True(t, response.ManagementUpdateAvailable)
}
func TestGetVersionInfo_Error(t *testing.T) {
manager := &mockInstanceManager{
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
return nil, errors.New("failed to fetch versions")
},
}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -1,263 +0,0 @@
package users
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
RequestsPerMinute: 10, // 10 attempts per minute per IP
Burst: 5, // Allow burst of 5 requests
CleanupInterval: 10 * time.Minute,
LimiterTTL: 30 * time.Minute,
})
// toUserInviteResponse converts a UserInvite to an API response.
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
autoGroups := invite.UserInfo.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var inviteLink *string
if invite.InviteToken != "" {
inviteLink = &invite.InviteToken
}
return api.UserInvite{
Id: invite.UserInfo.ID,
Email: invite.UserInfo.Email,
Name: invite.UserInfo.Name,
Role: invite.UserInfo.Role,
AutoGroups: autoGroups,
ExpiresAt: invite.InviteExpiresAt.UTC(),
CreatedAt: invite.InviteCreatedAt.UTC(),
Expired: time.Now().After(invite.InviteExpiresAt),
InviteToken: inviteLink,
}
}
// invitesHandler handles user invite operations
type invitesHandler struct {
accountManager account.Manager
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Create a subrouter for public invite endpoints with rate limiting middleware
publicRouter := router.PathPrefix("/users/invites").Subrouter()
publicRouter.Use(publicInviteRateLimiter.Middleware)
// Public endpoints (no auth required, protected by token and rate limited)
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := make([]api.UserInvite, 0, len(invites))
for _, invite := range invites {
resp = append(resp, toUserInviteResponse(invite))
}
util.WriteJSONObject(r.Context(), w, resp)
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
invite := &types.UserInfo{
Email: req.Email,
Name: req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
result.InviteCreatedAt = time.Now().UTC()
resp := toUserInviteResponse(result)
util.WriteJSONObject(r.Context(), w, &resp)
}
// getInviteInfo handles GET /api/users/invites/{token}
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := info.ExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
Email: info.Email,
Name: info.Name,
ExpiresAt: expiresAt,
Valid: info.Valid,
InvitedBy: info.InvitedBy,
})
}
// acceptInvite handles POST /api/users/invites/{token}/accept
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
var req api.UserInviteAcceptRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
var req api.UserInviteRegenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Allow empty body (io.EOF) - expiresIn is optional
if !errors.Is(err, io.EOF) {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := result.InviteExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
InviteToken: result.InviteToken,
InviteExpiresAt: expiresAt,
})
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -1,642 +0,0 @@
package users
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testInviteID = "test-invite-id"
testInviteToken = "nbi_testtoken123456789012345678"
testEmail = "invite@example.com"
testName = "Test User"
)
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
return &invitesHandler{
accountManager: am,
}
}
func TestListInvites(t *testing.T) {
now := time.Now().UTC()
testInvites := []*types.UserInvite{
{
UserInfo: &types.UserInfo{
ID: "invite-1",
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
},
InviteExpiresAt: now.Add(24 * time.Hour),
InviteCreatedAt: now,
},
{
UserInfo: &types.UserInfo{
ID: "invite-2",
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: nil,
},
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
InviteCreatedAt: now.Add(-48 * time.Hour),
},
}
tt := []struct {
name string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
expectedCount int
}{
{
name: "successful list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return testInvites, nil
},
expectedCount: 2,
},
{
name: "empty list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return []*types.UserInvite{}, nil
},
expectedCount: 0,
},
{
name: "permission denied",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
expectedCount: 0,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
ListUserInvitesFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp []api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Len(t, resp, tc.expectedCount)
}
})
}
}
func TestCreateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful create",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful create with custom expiration",
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 3600, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: []string{},
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "user already exists",
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
},
},
{
name: "invite already exists",
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
},
},
{
name: "permission denied",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
CreateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testInviteID, resp.Id)
assert.NotNil(t, resp.InviteToken)
assert.NotEmpty(t, *resp.InviteToken)
}
})
}
}
func TestGetInviteInfo(t *testing.T) {
now := time.Now().UTC()
tt := []struct {
name string
token string
expectedStatus int
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
}{
{
name: "successful get valid invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(24 * time.Hour),
Valid: true,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "successful get expired invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(-24 * time.Hour),
Valid: false,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invalid token format",
token: "invalid",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
},
},
{
name: "missing token",
token: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
GetUserInviteInfoFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.getInviteInfo(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteInfo
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testEmail, resp.Email)
assert.Equal(t, testName, resp.Name)
}
})
}
}
func TestAcceptInvite(t *testing.T) {
tt := []struct {
name string
token string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, token, password string) error
}{
{
name: "successful accept",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token, password string) error {
return nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invite expired",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "invite has expired")
},
},
{
name: "embedded IDP not enabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing token",
token: "",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON",
token: testInviteToken,
requestBody: `{invalid}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
{
name: "password too short",
token: testInviteToken,
requestBody: `{"password":"Short1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
},
},
{
name: "password missing digit",
token: testInviteToken,
requestBody: `{"password":"NoDigitPass!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
},
},
{
name: "password missing uppercase",
token: testInviteToken,
requestBody: `{"password":"nouppercase1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
},
},
{
name: "password missing special character",
token: testInviteToken,
requestBody: `{"password":"NoSpecial123"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
AcceptUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.acceptInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteAcceptResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.True(t, resp.Success)
}
})
}
}
func TestRegenerateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
inviteID string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful regenerate with empty body",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 0, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful regenerate with custom expiration",
inviteID: testInviteID,
requestBody: `{"expires_in":7200}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 7200, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
requestBody: "",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "missing invite ID",
inviteID: "",
requestBody: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON should return error",
inviteID: testInviteID,
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
RegenerateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
var body io.Reader
if tc.requestBody != "" {
body = bytes.NewBufferString(tc.requestBody)
}
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteRegenerateResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.InviteToken)
}
})
}
}
func TestDeleteInvite(t *testing.T) {
tt := []struct {
name string
inviteID string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}{
{
name: "successful delete",
inviteID: testInviteID,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
inviteID: testInviteID,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing invite ID",
inviteID: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
DeleteUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -2,14 +2,10 @@ package middleware
import (
"context"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// RateLimiterConfig holds configuration for the API rate limiter
@@ -148,25 +144,3 @@ func (rl *APIRateLimiter) Reset(key string) {
defer rl.mu.Unlock()
delete(rl.limiters, key)
}
// Middleware returns an HTTP middleware that rate limits requests by client IP.
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
return
}
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -1,158 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}

View File

@@ -20,7 +20,7 @@ const (
staticClientCLI = "netbird-cli"
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email groups"
defaultScopes = "openid profile email"
defaultUserIDClaim = "sub"
)

View File

@@ -2,54 +2,18 @@ package instance
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"strings"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const (
// Version endpoints
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
// Cache TTL for version information
versionCacheTTL = 60 * time.Minute
// HTTP client timeout
httpTimeout = 5 * time.Second
)
// VersionInfo contains version information for NetBird components
type VersionInfo struct {
// CurrentVersion is the running management server version
CurrentVersion string
// DashboardVersion is the latest available dashboard version from GitHub
DashboardVersion string
// ManagementVersion is the latest available management version from GitHub
ManagementVersion string
// ManagementUpdateAvailable indicates if a newer management version is available
ManagementUpdateAvailable bool
}
// githubRelease represents a GitHub release response
type githubRelease struct {
TagName string `json:"tag_name"`
}
// Manager handles instance-level operations like initial setup.
type Manager interface {
// IsSetupRequired checks if instance setup is required.
@@ -59,9 +23,6 @@ type Manager interface {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
// DefaultManager is the default implementation of Manager.
@@ -71,12 +32,6 @@ type DefaultManager struct {
setupRequired bool
setupMu sync.RWMutex
// Version caching
httpClient *http.Client
versionMu sync.RWMutex
cachedVersions *VersionInfo
lastVersionFetch time.Time
}
// NewManager creates a new instance manager.
@@ -88,9 +43,6 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
@@ -182,130 +134,3 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
}
return nil
}
// GetVersionInfo returns version information for NetBird components.
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.RLock()
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.RUnlock()
return &cached, nil
}
m.versionMu.RUnlock()
return m.fetchVersionInfo(ctx)
}
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.Lock()
// Double-check after acquiring write lock
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.Unlock()
return &cached, nil
}
m.versionMu.Unlock()
info := &VersionInfo{
CurrentVersion: version.NetbirdVersion(),
}
// Fetch management version from pkgs.netbird.io (plain text)
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
} else {
info.ManagementVersion = mgmtVersion
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
}
// Fetch dashboard version from GitHub
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
} else {
info.DashboardVersion = dashVersion
}
// Update cache
m.versionMu.Lock()
m.cachedVersions = info
m.lastVersionFetch = time.Now()
m.versionMu.Unlock()
return info, nil
}
// isNewerVersion returns true if latestVersion is greater than currentVersion
func isNewerVersion(currentVersion, latestVersion string) bool {
current, err := goversion.NewVersion(currentVersion)
if err != nil {
return false
}
latest, err := goversion.NewVersion(latestVersion)
if err != nil {
return false
}
return latest.GreaterThan(current)
}
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
if err != nil {
return "", fmt.Errorf("read response: %w", err)
}
return strings.TrimSpace(string(body)), nil
}
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var release githubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return "", fmt.Errorf("decode response: %w", err)
}
// Remove 'v' prefix if present
tag := release.TagName
if len(tag) > 0 && tag[0] == 'v' {
tag = tag[1:]
}
return tag, nil
}

View File

@@ -1,285 +0,0 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockRoundTripper implements http.RoundTripper for testing
type mockRoundTripper struct {
callCount atomic.Int32
managementVersion string
dashboardVersion string
}
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.callCount.Add(1)
var body string
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
// Plain text response for management version
body = m.managementVersion
} else if strings.Contains(req.URL.String(), "github.com") {
// JSON response for dashboard version
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
body = string(jsonResp)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: make(http.Header),
}, nil
}
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
info, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
// CurrentVersion should always be set
assert.NotEmpty(t, info.CurrentVersion)
assert.Equal(t, "0.65.0", info.ManagementVersion)
assert.Equal(t, "2.10.0", info.DashboardVersion)
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
}
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
// First call
info1, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.NotEmpty(t, info1.CurrentVersion)
assert.Equal(t, "0.65.0", info1.ManagementVersion)
initialCallCount := mockTransport.callCount.Load()
// Second call should use cache (no additional HTTP calls)
info2, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
// Verify no additional HTTP calls were made (cache was used)
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
}
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
tests := []struct {
name string
tagName string
expected string
shouldError bool
}{
{
name: "tag with v prefix",
tagName: "v1.2.3",
expected: "1.2.3",
},
{
name: "tag without v prefix",
tagName: "1.2.3",
expected: "1.2.3",
},
{
name: "tag with prerelease",
tagName: "v2.0.0-beta.1",
expected: "2.0.0-beta.1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
if tc.shouldError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expected, version)
}
})
}
}
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
}{
{
name: "not found",
statusCode: http.StatusNotFound,
body: `{"message": "Not Found"}`,
},
{
name: "rate limited",
statusCode: http.StatusForbidden,
body: `{"message": "API rate limit exceeded"}`,
},
{
name: "server error",
statusCode: http.StatusInternalServerError,
body: `{"message": "Internal Server Error"}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.statusCode)
_, _ = w.Write([]byte(tc.body))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
})
}
}
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{invalid json}`))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
}
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := m.fetchGitHubRelease(ctx, server.URL)
assert.Error(t, err)
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
currentVersion string
latestVersion string
expected bool
}{
{
name: "latest is newer - minor version",
currentVersion: "0.64.1",
latestVersion: "0.65.0",
expected: true,
},
{
name: "latest is newer - patch version",
currentVersion: "0.64.1",
latestVersion: "0.64.2",
expected: true,
},
{
name: "latest is newer - major version",
currentVersion: "0.64.1",
latestVersion: "1.0.0",
expected: true,
},
{
name: "versions are equal",
currentVersion: "0.64.1",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - minor version",
currentVersion: "0.65.0",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - patch version",
currentVersion: "0.64.2",
latestVersion: "0.64.1",
expected: false,
},
{
name: "development version",
currentVersion: "development",
latestVersion: "0.65.0",
expected: false,
},
{
name: "invalid latest version",
currentVersion: "0.64.1",
latestVersion: "invalid",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -139,12 +139,6 @@ type MockAccountManager struct {
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
@@ -719,48 +713,6 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if am.CreateUserInviteFunc != nil {
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
}
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if am.AcceptUserInviteFunc != nil {
return am.AcceptUserInviteFunc(ctx, token, password)
}
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
}
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if am.RegenerateUserInviteFunc != nil {
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
}
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if am.GetUserInviteInfoFunc != nil {
return am.GetUserInviteInfoFunc(ctx, token)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
}
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if am.ListUserInvitesFunc != nil {
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
}
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if am.DeleteUserInviteFunc != nil {
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
}
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)

View File

@@ -150,8 +150,26 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
now := time.Now().UTC()
newStatus.LastSeen = now
newStatus.Connected = connected
if oldStatus.Connected == connected {
log.WithContext(ctx).Warnf("peer %s status race: already connected=%t (lastSeen=%s, now=%s, ephemeral=%t)",
peer.ID,
connected,
oldStatus.LastSeen.Format(time.RFC3339Nano),
now.Format(time.RFC3339Nano),
peer.Ephemeral)
}
if !connected && oldStatus.Connected && peer.Ephemeral {
log.WithContext(ctx).Tracef("ephemeral peer %s disconnecting (lastSeen=%s, now=%s)",
peer.ID,
oldStatus.LastSeen.Format(time.RFC3339Nano),
now.Format(time.RFC3339Nano))
}
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {
newStatus.LoginExpired = false

View File

@@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
&types.Job{}, &zones.Zone{}, &records.Record{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -815,130 +815,6 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return &user, nil
}
// SaveUserInvite saves a user invite to the database
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
inviteCopy := invite.Copy()
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt invite: %w", err)
}
result := s.db.Save(inviteCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user invite to store")
}
return nil
}
// GetUserInviteByID retrieves a user invite by its ID and account ID
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByEmail retrieves a user invite by account ID and email.
// Since email is encrypted with random IVs, we fetch all invites for the account
// and compare emails in memory after decryption.
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
if strings.EqualFold(invite.Email, email) {
return invite, nil
}
}
return nil, status.Errorf(status.NotFound, "user invite not found for email")
}
// GetAccountUserInvites retrieves all user invites for an account
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
}
return invites, nil
}
// DeleteUserInvite deletes a user invite by its ID
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete user invite from store")
}
return nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
@@ -4393,9 +4269,6 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
Take(&userID, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "peer not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
}

View File

@@ -1,520 +0,0 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/types"
)
func TestSqlStore_SaveUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-1",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1", "group-2"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the invite was saved
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
assert.Equal(t, invite.Name, retrieved.Name)
assert.Equal(t, invite.Role, retrieved.Role)
assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups)
assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy)
})
}
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-update",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Update the invite with a new token
invite.HashedToken = "new-hashed-token"
invite.ExpiresAt = time.Now().Add(24 * time.Hour)
err = store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the update
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
})
}
func TestSqlStore_GetUserInviteByID(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-id",
AccountID: "account-1",
Email: "getbyid@example.com",
Name: "Get By ID User",
Role: "admin",
AutoGroups: []string{},
HashedToken: "hashed-token-get",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by ID - success
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by ID - wrong account
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
assert.Error(t, err)
// Get by ID - not found
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-token",
AccountID: "account-1",
Email: "getbytoken@example.com",
Name: "Get By Token User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "unique-hashed-token-456",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by hashed token - success
retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by hashed token - not found
_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-email",
AccountID: "account-email-test",
Email: "unique-email@example.com",
Name: "Get By Email User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-email",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by email - success
retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - case insensitive
retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - wrong account
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
assert.Error(t, err)
// Get by email - not found
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
assert.Error(t, err)
})
}
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
accountID := "account-list-invites"
invites := []*types.UserInviteRecord{
{
ID: "invite-list-1",
AccountID: accountID,
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-list-1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-2",
AccountID: accountID,
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: []string{"group-2"},
HashedToken: "hashed-token-list-2",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-3",
AccountID: "different-account",
Email: "user3@example.com",
Name: "User Three",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-list-3",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
}
for _, invite := range invites {
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
}
// Get all invites for the account
retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
assert.Len(t, retrieved, 2)
// Verify the invites belong to the correct account
for _, invite := range retrieved {
assert.Equal(t, accountID, invite.AccountID)
}
// Get invites for account with no invites
retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
require.NoError(t, err)
assert.Len(t, retrieved, 0)
})
}
func TestSqlStore_DeleteUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-delete",
AccountID: "account-delete-test",
Email: "delete@example.com",
Name: "Delete User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-delete",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify invite exists
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Delete the invite
err = store.DeleteUserInvite(ctx, invite.ID)
require.NoError(t, err)
// Verify invite is deleted
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
assert.Error(t, err)
})
}
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-encrypted",
AccountID: "account-encrypted",
Email: "sensitive-email@example.com",
Name: "Sensitive Name",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-encrypted",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Retrieve and verify decryption works
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
assert.Equal(t, "Sensitive Name", retrieved.Name)
})
}
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Deleting a non-existent invite should not return an error
err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
require.NoError(t, err)
})
}
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
email := "shared-email@example.com"
// Create invite in first account
invite1 := &types.UserInviteRecord{
ID: "invite-account1",
AccountID: "account-1",
Email: email,
Name: "User Account 1",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-account1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-1",
}
// Create invite in second account with same email
invite2 := &types.UserInviteRecord{
ID: "invite-account2",
AccountID: "account-2",
Email: email,
Name: "User Account 2",
Role: "admin",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-account2",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-2",
}
err := store.SaveUserInvite(ctx, invite1)
require.NoError(t, err)
err = store.SaveUserInvite(ctx, invite2)
require.NoError(t, err)
// Verify each account gets the correct invite by email
retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
require.NoError(t, err)
assert.Equal(t, "invite-account1", retrieved1.ID)
assert.Equal(t, "User Account 1", retrieved1.Name)
retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
require.NoError(t, err)
assert.Equal(t, "invite-account2", retrieved2.ID)
assert.Equal(t, "User Account 2", retrieved2.Name)
})
}
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-locking",
AccountID: "account-locking",
Email: "locking@example.com",
Name: "Locking Test User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-locking",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Test with different locking strengths
lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}
for _, strength := range lockStrengths {
retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
require.NoError(t, err)
assert.Len(t, invites, 1)
}
})
}
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Test with nil AutoGroups
invite := &types.UserInviteRecord{
ID: "invite-nil-autogroups",
AccountID: "account-autogroups",
Email: "nilgroups@example.com",
Name: "Nil Groups User",
Role: "user",
AutoGroups: nil,
HashedToken: "hashed-token-nil",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Should return empty slice or nil, both are acceptable
assert.Empty(t, retrieved.AutoGroups)
})
}
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
now := time.Now().UTC().Truncate(time.Millisecond)
expiresAt := now.Add(72 * time.Hour)
invite := &types.UserInviteRecord{
ID: "invite-timestamp",
AccountID: "account-timestamp",
Email: "timestamp@example.com",
Name: "Timestamp User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-timestamp",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Verify timestamps are preserved (within reasonable precision)
assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
})
}

View File

@@ -92,13 +92,6 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
DeleteUserInvite(ctx context.Context, inviteID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)

View File

@@ -1,201 +0,0 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
const (
// InviteTokenPrefix is the prefix for invite tokens
InviteTokenPrefix = "nbi_"
// InviteTokenSecretLength is the length of the random secret part
InviteTokenSecretLength = 30
// InviteTokenChecksumLength is the length of the encoded checksum
InviteTokenChecksumLength = 6
// InviteTokenLength is the total length of the token (4 + 30 + 6 = 40)
InviteTokenLength = 40
// DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours)
DefaultInviteExpirationSeconds = 259200
// MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour)
MinInviteExpirationSeconds = 3600
)
// UserInviteRecord represents an invitation for a user to set up their account (database model)
type UserInviteRecord struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index;not null"`
Email string `gorm:"index;not null"`
Name string `gorm:"not null"`
Role string `gorm:"not null"`
AutoGroups []string `gorm:"serializer:json"`
HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded)
ExpiresAt time.Time `gorm:"not null"`
CreatedAt time.Time `gorm:"not null"`
CreatedBy string `gorm:"not null"`
}
// TableName returns the table name for GORM
func (UserInviteRecord) TableName() string {
return "user_invites"
}
// GenerateInviteToken creates a new invite token with the format: nbi_<secret><checksum>
// Returns the hashed token (for storage) and the plain token (to give to the user)
func GenerateInviteToken() (hashedToken string, plainToken string, err error) {
secret, err := b.Random(InviteTokenSecretLength)
if err != nil {
return "", "", fmt.Errorf("failed to generate random secret: %w", err)
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
// Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode)
paddedChecksum := encodedChecksum
if len(paddedChecksum) < InviteTokenChecksumLength {
paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum
}
plainToken = InviteTokenPrefix + secret + paddedChecksum
hash := sha256.Sum256([]byte(plainToken))
hashedToken = b64.StdEncoding.EncodeToString(hash[:])
return hashedToken, plainToken, nil
}
// HashInviteToken creates a SHA-256 hash of the token (base64 encoded)
func HashInviteToken(token string) string {
hash := sha256.Sum256([]byte(token))
return b64.StdEncoding.EncodeToString(hash[:])
}
// ValidateInviteToken validates the token format and checksum.
// Returns an error if the token is invalid.
func ValidateInviteToken(token string) error {
if len(token) != InviteTokenLength {
return fmt.Errorf("invalid token length")
}
prefix := token[:len(InviteTokenPrefix)]
if prefix != InviteTokenPrefix {
return fmt.Errorf("invalid token prefix")
}
secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return fmt.Errorf("checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return fmt.Errorf("checksum does not match")
}
return nil
}
// IsExpired checks if the invite has expired
func (i *UserInviteRecord) IsExpired() bool {
return time.Now().After(i.ExpiresAt)
}
// UserInvite contains the result of creating or regenerating an invite
type UserInvite struct {
UserInfo *UserInfo
InviteToken string
InviteExpiresAt time.Time
InviteCreatedAt time.Time
}
// UserInviteInfo contains public information about an invite (for unauthenticated endpoint)
type UserInviteInfo struct {
Email string `json:"email"`
Name string `json:"name"`
ExpiresAt time.Time `json:"expires_at"`
Valid bool `json:"valid"`
InvitedBy string `json:"invited_by"`
}
// NewInviteID generates a new invite ID using xid
func NewInviteID() string {
return xid.New().String()
}
// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Encrypt(i.Email)
if err != nil {
return fmt.Errorf("encrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Encrypt(i.Name)
if err != nil {
return fmt.Errorf("encrypt name: %w", err)
}
}
return nil
}
// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Decrypt(i.Email)
if err != nil {
return fmt.Errorf("decrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Decrypt(i.Name)
if err != nil {
return fmt.Errorf("decrypt name: %w", err)
}
}
return nil
}
// Copy creates a deep copy of the UserInviteRecord
func (i *UserInviteRecord) Copy() *UserInviteRecord {
autoGroups := make([]string, len(i.AutoGroups))
copy(autoGroups, i.AutoGroups)
return &UserInviteRecord{
ID: i.ID,
AccountID: i.AccountID,
Email: i.Email,
Name: i.Name,
Role: i.Role,
AutoGroups: autoGroups,
HashedToken: i.HashedToken,
ExpiresAt: i.ExpiresAt,
CreatedAt: i.CreatedAt,
CreatedBy: i.CreatedBy,
}
}

View File

@@ -1,355 +0,0 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/crc32"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUserInviteRecord_TableName(t *testing.T) {
invite := UserInviteRecord{}
assert.Equal(t, "user_invites", invite.TableName())
}
func TestGenerateInviteToken_Success(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.NotEmpty(t, hashedToken)
assert.NotEmpty(t, plainToken)
}
func TestGenerateInviteToken_Length(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.Len(t, plainToken, InviteTokenLength)
}
func TestGenerateInviteToken_Prefix(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
}
func TestGenerateInviteToken_Hashing(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
expectedHash := sha256.Sum256([]byte(plainToken))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestGenerateInviteToken_Checksum(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Extract parts
secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]
// Verify checksum
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
actualChecksum, err := base62.Decode(checksumStr)
require.NoError(t, err)
assert.Equal(t, expectedChecksum, actualChecksum)
}
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
tokens := make(map[string]bool)
for i := 0; i < 100; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.False(t, tokens[plainToken], "Token should be unique")
tokens[plainToken] = true
}
}
func TestHashInviteToken(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hashedToken := HashInviteToken(token)
expectedHash := sha256.Sum256([]byte(token))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestHashInviteToken_Consistency(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hash1 := HashInviteToken(token)
hash2 := HashInviteToken(token)
assert.Equal(t, hash1, hash2)
}
func TestHashInviteToken_DifferentTokens(t *testing.T) {
token1 := "nbi_testtoken123456789012345678901234"
token2 := "nbi_testtoken123456789012345678901235"
hash1 := HashInviteToken(token1)
hash2 := HashInviteToken(token2)
assert.NotEqual(t, hash1, hash2)
}
func TestValidateInviteToken_Success(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err)
}
func TestValidateInviteToken_InvalidLength(t *testing.T) {
testCases := []struct {
name string
token string
}{
{"empty", ""},
{"too short", "nbi_abc"},
{"too long", "nbi_" + strings.Repeat("a", 50)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateInviteToken(tc.token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token length")
})
}
}
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
// Create a token with wrong prefix but correct length
token := "xyz_" + strings.Repeat("a", 30) + "000000"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token prefix")
}
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
// Create a token with correct format but invalid checksum
token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "checksum")
}
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}
func TestUserInviteRecord_IsExpired(t *testing.T) {
t.Run("not expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(time.Hour),
}
assert.False(t, invite.IsExpired())
})
t.Run("expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Hour),
}
assert.True(t, invite.IsExpired())
})
t.Run("just expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Second),
}
assert.True(t, invite.IsExpired())
})
}
func TestNewInviteID(t *testing.T) {
id := NewInviteID()
assert.NotEmpty(t, id)
assert.Len(t, id, 20) // xid generates 20 character IDs
}
func TestNewInviteID_Uniqueness(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := NewInviteID()
assert.False(t, ids[id], "ID should be unique")
ids[id] = true
}
}
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
t.Run("encrypt and decrypt", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
// Encrypt
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify encrypted values are different from original
assert.NotEqual(t, "test@example.com", invite.Email)
assert.NotEqual(t, "Test User", invite.Name)
// Decrypt
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify decrypted values match original
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
t.Run("encrypt empty fields", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "",
Name: "",
Role: "user",
}
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
})
t.Run("nil encryptor", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
err := invite.EncryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
err = invite.DecryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
}
func TestUserInviteRecord_Copy(t *testing.T) {
now := time.Now()
expiresAt := now.Add(72 * time.Hour)
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group1", "group2"},
HashedToken: "hashed-token",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "creator-id",
}
copied := original.Copy()
// Verify all fields are copied
assert.Equal(t, original.ID, copied.ID)
assert.Equal(t, original.AccountID, copied.AccountID)
assert.Equal(t, original.Email, copied.Email)
assert.Equal(t, original.Name, copied.Name)
assert.Equal(t, original.Role, copied.Role)
assert.Equal(t, original.AutoGroups, copied.AutoGroups)
assert.Equal(t, original.HashedToken, copied.HashedToken)
assert.Equal(t, original.ExpiresAt, copied.ExpiresAt)
assert.Equal(t, original.CreatedAt, copied.CreatedAt)
assert.Equal(t, original.CreatedBy, copied.CreatedBy)
// Verify deep copy of AutoGroups (modifying copy doesn't affect original)
copied.AutoGroups[0] = "modified"
assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0])
assert.Equal(t, "group1", original.AutoGroups[0])
}
func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: []string{},
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: nil,
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestInviteTokenConstants(t *testing.T) {
// Verify constants are consistent
expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength
assert.Equal(t, InviteTokenLength, expectedLength)
assert.Equal(t, 4, len(InviteTokenPrefix))
assert.Equal(t, 30, InviteTokenSecretLength)
assert.Equal(t, 6, InviteTokenChecksumLength)
assert.Equal(t, 40, InviteTokenLength)
assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours
assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour
}
func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) {
// Generate multiple tokens and ensure they all validate
for i := 0; i < 50; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err, "Generated token should always be valid")
}
}
func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// HashInviteToken should produce the same hash as GenerateInviteToken
rehashedToken := HashInviteToken(plainToken)
assert.Equal(t, hashedToken, rehashedToken)
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"strings"
"time"
"unicode"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -1454,368 +1453,3 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init
return nil
}
// CreateUserInvite creates an invite link for a new user in the embedded IdP.
// The user is NOT created until the invite is accepted.
func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if err := validateUserInvite(invite); err != nil {
return nil, err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Check if user already exists in NetBird DB
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
for _, user := range existingUsers {
if strings.EqualFold(user.Email, invite.Email) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
}
}
// Check if invite already exists for this email
existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return nil, fmt.Errorf("failed to check existing invites: %w", err)
}
}
if existingInvite != nil {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate invite token
inviteID := types.NewInviteID()
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Create the invite record (no user created yet)
userInvite := &types.UserInviteRecord{
ID: inviteID,
AccountID: accountID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
HashedToken: hashedToken,
ExpiresAt: expiresAt,
CreatedAt: time.Now().UTC(),
CreatedBy: initiatorUserID,
}
if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// GetUserInviteInfo retrieves invite information from a token (public endpoint).
func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if err := types.ValidateInviteToken(token); err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken)
if err != nil {
return nil, err
}
// Get the inviter's name
invitedBy := ""
if invite.CreatedBy != "" {
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy)
if err == nil && inviter != nil {
invitedBy = inviter.Name
}
}
return &types.UserInviteInfo{
Email: invite.Email,
Name: invite.Name,
ExpiresAt: invite.ExpiresAt,
Valid: !invite.IsExpired(),
InvitedBy: invitedBy,
}, nil
}
// ListUserInvites returns all invites for an account.
func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
invites := make([]*types.UserInvite, 0, len(records))
for _, record := range records {
invites = append(invites, &types.UserInvite{
UserInfo: &types.UserInfo{
ID: record.ID,
Email: record.Email,
Name: record.Name,
Role: record.Role,
AutoGroups: record.AutoGroups,
},
InviteExpiresAt: record.ExpiresAt,
InviteCreatedAt: record.CreatedAt,
})
}
return invites, nil
}
// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB.
func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}
if err := validatePassword(password); err != nil {
return status.Errorf(status.InvalidArgument, "invalid password: %v", err)
}
if err := types.ValidateInviteToken(token); err != nil {
return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken)
if err != nil {
return err
}
if invite.IsExpired() {
return status.Errorf(status.InvalidArgument, "invite has expired")
}
// Create user in Dex with the provided password
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return status.Errorf(status.Internal, "failed to get embedded IdP manager")
}
idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name)
if err != nil {
return fmt.Errorf("failed to create user in IdP: %w", err)
}
// Create user in NetBird DB
newUser := &types.User{
Id: idpUser.ID,
AccountID: invite.AccountID,
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: types.UserIssuedAPI,
CreatedAt: time.Now().UTC(),
Email: invite.Email,
Name: invite.Name,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.SaveUser(ctx, newUser); err != nil {
return fmt.Errorf("failed to save user: %w", err)
}
if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil {
return fmt.Errorf("failed to delete invite: %w", err)
}
return nil
})
if err != nil {
// Best-effort rollback: delete the IdP user to avoid orphaned records
if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil {
log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID)
}
return err
}
am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email})
return nil
}
// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one.
func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Get existing invite
existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return nil, err
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate new invite token
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Update existing invite with new token and expiration
existingInvite.HashedToken = hashedToken
existingInvite.ExpiresAt = expiresAt
existingInvite.CreatedBy = initiatorUserID
err = am.Store.SaveUserInvite(ctx, existingInvite)
if err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: existingInvite.ID,
Email: existingInvite.Email,
Name: existingInvite.Name,
Role: existingInvite.Role,
AutoGroups: existingInvite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// DeleteUserInvite deletes an existing invite by ID.
func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return err
}
if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil {
return err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email})
return nil
}
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}
var hasDigit, hasUpper, hasSpecial bool
for _, c := range password {
switch {
case unicode.IsDigit(c):
hasDigit = true
case unicode.IsUpper(c):
hasUpper = true
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
hasSpecial = true
}
}
var missing []string
if !hasDigit {
missing = append(missing, "one digit")
}
if !hasUpper {
missing = append(missing, "one uppercase letter")
}
if !hasSpecial {
missing = append(missing, "one special character")
}
if len(missing) > 0 {
return errors.New("password must contain at least " + strings.Join(missing, ", "))
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -72,8 +72,8 @@ var (
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
keys, err := getPemKeys(keysLocation)
if err != nil && !strings.Contains(keysLocation, "localhost") {
log.WithField("keysLocation", keysLocation).Warnf("could not get keys from location: %s, it will try again on the next http request", err)
if err != nil {
log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
}
return &Validator{

View File

@@ -488,171 +488,6 @@ components:
- role
- auto_groups
- is_service_user
UserInviteCreateRequest:
type: object
description: Request to create a user invite link
properties:
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
role:
description: User's NetBird account role
type: string
example: user
auto_groups:
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
expires_in:
description: Invite expiration time in seconds (default 72 hours)
type: integer
example: 259200
required:
- email
- name
- role
- auto_groups
UserInvite:
type: object
description: A user invite
properties:
id:
description: Invite ID
type: string
example: d5p7eedra0h0lt6f59hg
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
role:
description: User's NetBird account role
type: string
example: user
auto_groups:
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
expires_at:
description: Invite expiration time
type: string
format: date-time
example: "2024-01-25T10:00:00Z"
created_at:
description: Invite creation time
type: string
format: date-time
example: "2024-01-22T10:00:00Z"
expired:
description: Whether the invite has expired
type: boolean
example: false
invite_token:
description: The invite link to be shared with the user. Only returned when the invite is created or regenerated.
type: string
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
required:
- id
- email
- name
- role
- auto_groups
- expires_at
- created_at
- expired
UserInviteInfo:
type: object
description: Public information about an invite
properties:
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
expires_at:
description: Invite expiration time
type: string
format: date-time
example: "2024-01-25T10:00:00Z"
valid:
description: Whether the invite is still valid (not expired)
type: boolean
example: true
invited_by:
description: Name of the user who sent the invite
type: string
example: Admin User
required:
- email
- name
- expires_at
- valid
- invited_by
UserInviteAcceptRequest:
type: object
description: Request to accept an invite and set password
properties:
password:
description: >-
The password the user wants to set. Must be at least 8 characters long
and contain at least one uppercase letter, one digit, and one special
character (any character that is not a letter or digit, including spaces).
type: string
format: password
minLength: 8
pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$'
example: SecurePass123!
required:
- password
UserInviteAcceptResponse:
type: object
description: Response after accepting an invite
properties:
success:
description: Whether the invite was accepted successfully
type: boolean
example: true
required:
- success
UserInviteRegenerateRequest:
type: object
description: Request to regenerate an invite link
properties:
expires_in:
description: Invite expiration time in seconds (default 72 hours)
type: integer
example: 259200
UserInviteRegenerateResponse:
type: object
description: Response after regenerating an invite
properties:
invite_token:
description: The new invite token
type: string
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
invite_expires_at:
description: New invite expiration time
type: string
format: date-time
example: "2024-01-28T10:00:00Z"
required:
- invite_token
- invite_expires_at
PeerMinimum:
type: object
properties:
@@ -2236,8 +2071,7 @@ components:
"dns.zone.create", "dns.zone.update", "dns.zone.delete",
"dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete",
"peer.job.create",
"user.password.change",
"user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete"
"user.password.change"
]
example: route.add
initiator_id:
@@ -2808,29 +2642,6 @@ components:
required:
- user_id
- email
InstanceVersionInfo:
type: object
description: Version information for NetBird components
properties:
management_current_version:
description: The current running version of the management server
type: string
example: "0.35.0"
dashboard_available_version:
description: The latest available version of the dashboard (from GitHub releases)
type: string
example: "2.10.0"
management_available_version:
description: The latest available version of the management server (from GitHub releases)
type: string
example: "0.35.0"
management_update_available:
description: Indicates if a newer management version is available
type: boolean
example: true
required:
- management_current_version
- management_update_available
responses:
not_found:
description: Resource not found
@@ -2883,27 +2694,6 @@ paths:
$ref: '#/components/schemas/InstanceStatus'
'500':
"$ref": "#/components/responses/internal_error"
/api/instance/version:
get:
summary: Get Version Info
description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub.
tags: [ Instance ]
security:
- BearerAuth: []
- TokenAuth: []
responses:
'200':
description: Version information
content:
application/json:
schema:
$ref: '#/components/schemas/InstanceVersionInfo'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup:
post:
summary: Setup Instance
@@ -3522,210 +3312,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites:
get:
summary: List user invites
description: Lists all pending invites for the account. Only available when embedded IdP is enabled.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: List of invites
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/UserInvite'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a user invite
description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: User invite information
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteCreateRequest'
responses:
'200':
description: Invite created successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInvite'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'409':
description: User or invite already exists
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{inviteId}:
delete:
summary: Delete a user invite
description: Deletes a pending invite. Only available when embedded IdP is enabled.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: inviteId
required: true
schema:
type: string
description: The ID of the invite to delete
responses:
'200':
description: Invite deleted successfully
content: { }
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
description: Invite not found
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{inviteId}/regenerate:
post:
summary: Regenerate a user invite
description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: inviteId
required: true
schema:
type: string
description: The ID of the invite to regenerate
requestBody:
description: Regenerate options
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteRegenerateRequest'
responses:
'200':
description: Invite regenerated successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteRegenerateResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
description: Invite not found
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{token}:
get:
summary: Get invite information
description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself.
tags: [ Users ]
security: []
parameters:
- in: path
name: token
required: true
schema:
type: string
description: The invite token
responses:
'200':
description: Invite information
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteInfo'
'400':
"$ref": "#/components/responses/bad_request"
'404':
description: Invite not found or invalid token
content: { }
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{token}/accept:
post:
summary: Accept an invite
description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself.
tags: [ Users ]
security: []
parameters:
- in: path
name: token
required: true
schema:
type: string
description: The invite token
requestBody:
description: Password to set for the new user
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteAcceptRequest'
responses:
'200':
description: Invite accepted successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteAcceptResponse'
'400':
"$ref": "#/components/responses/bad_request"
'404':
description: Invite not found or invalid token
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled or invite expired
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: List all Peers

View File

@@ -123,10 +123,6 @@ const (
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
EventActivityCodeUserInvite EventActivityCode = "user.invite"
EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept"
EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create"
EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete"
EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate"
EventActivityCodeUserJoin EventActivityCode = "user.join"
EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change"
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
@@ -874,21 +870,6 @@ type InstanceStatus struct {
SetupRequired bool `json:"setup_required"`
}
// InstanceVersionInfo Version information for NetBird components
type InstanceVersionInfo struct {
// DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases)
DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"`
// ManagementAvailableVersion The latest available version of the management server (from GitHub releases)
ManagementAvailableVersion *string `json:"management_available_version,omitempty"`
// ManagementCurrentVersion The current running version of the management server
ManagementCurrentVersion string `json:"management_current_version"`
// ManagementUpdateAvailable Indicates if a newer management version is available
ManagementUpdateAvailable bool `json:"management_update_available"`
}
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
@@ -2185,99 +2166,6 @@ type UserCreateRequest struct {
Role string `json:"role"`
}
// UserInvite A user invite
type UserInvite struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// CreatedAt Invite creation time
CreatedAt time.Time `json:"created_at"`
// Email User's email address
Email string `json:"email"`
// Expired Whether the invite has expired
Expired bool `json:"expired"`
// ExpiresAt Invite expiration time
ExpiresAt time.Time `json:"expires_at"`
// Id Invite ID
Id string `json:"id"`
// InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated.
InviteToken *string `json:"invite_token,omitempty"`
// Name User's full name
Name string `json:"name"`
// Role User's NetBird account role
Role string `json:"role"`
}
// UserInviteAcceptRequest Request to accept an invite and set password
type UserInviteAcceptRequest struct {
// Password The password the user wants to set. Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces).
Password string `json:"password"`
}
// UserInviteAcceptResponse Response after accepting an invite
type UserInviteAcceptResponse struct {
// Success Whether the invite was accepted successfully
Success bool `json:"success"`
}
// UserInviteCreateRequest Request to create a user invite link
type UserInviteCreateRequest struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's email address
Email string `json:"email"`
// ExpiresIn Invite expiration time in seconds (default 72 hours)
ExpiresIn *int `json:"expires_in,omitempty"`
// Name User's full name
Name string `json:"name"`
// Role User's NetBird account role
Role string `json:"role"`
}
// UserInviteInfo Public information about an invite
type UserInviteInfo struct {
// Email User's email address
Email string `json:"email"`
// ExpiresAt Invite expiration time
ExpiresAt time.Time `json:"expires_at"`
// InvitedBy Name of the user who sent the invite
InvitedBy string `json:"invited_by"`
// Name User's full name
Name string `json:"name"`
// Valid Whether the invite is still valid (not expired)
Valid bool `json:"valid"`
}
// UserInviteRegenerateRequest Request to regenerate an invite link
type UserInviteRegenerateRequest struct {
// ExpiresIn Invite expiration time in seconds (default 72 hours)
ExpiresIn *int `json:"expires_in,omitempty"`
}
// UserInviteRegenerateResponse Response after regenerating an invite
type UserInviteRegenerateResponse struct {
// InviteExpiresAt New invite expiration time
InviteExpiresAt time.Time `json:"invite_expires_at"`
// InviteToken The new invite token
InviteToken string `json:"invite_token"`
}
// UserPermissions defines model for UserPermissions.
type UserPermissions struct {
// IsRestricted Indicates whether this User's Peers view is restricted
@@ -2530,15 +2418,6 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
type PostApiUsersJSONRequestBody = UserCreateRequest
// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType.
type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest
// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType.
type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest
// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType.
type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
type PutApiUsersUserIdJSONRequestBody = UserRequest