diff --git a/client/android/login.go b/client/android/login.go index 4d4c7a650..a9422cdbf 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -3,15 +3,7 @@ package android import ( "context" "fmt" - "time" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" @@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { } func (a *Auth) saveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - s, ok := gstatus.FromError(err) - if !ok { - return err - } - if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -129,19 +108,17 @@ func (a *Auth) 
LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK } func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + err, _ = authClient.Login(ctxWithValues, setupKey, "") if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return profilemanager.WriteOutConfig(a.cfgPath, a.config) @@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT } func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV) + tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV) if err != nil { return 
fmt.Errorf("interactive sso login failed: %v", err) } jwtToken = tokenInfo.GetTokenToUse() } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - - if err == nil { - go urlOpener.OnLoginSuccess() - } - - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil - } - return err - }) + err, _ = authClient.Login(a.ctx, "", jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } + go urlOpener.OnLoginSuccess() + return nil } -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "") +func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { + oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get OAuth flow: %v", err) } flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) @@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } return &tokenInfo, nil } - -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) 
-} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index bbb0ef0d6..e480df4d7 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -219,11 +219,33 @@ func runForDuration(cmd *cobra.Command, args []string) error { time.Sleep(3 * time.Second) + cpuProfilingStarted := false + if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to start CPU profiling: %v\n", err) + } else { + cpuProfilingStarted = true + defer func() { + if cpuProfilingStarted { + if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) + } + } + }() + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if cpuProfilingStarted { + if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { + cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) + } else { + cpuProfilingStarted = false + } + } + cmd.Println("Creating debug bundle...") request := &proto.DebugBundleRequest{ @@ -353,6 +375,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c StatusRecorder: recorder, SyncResponse: syncResponse, LogPath: logFilePath, + CPUProfile: nil, }, debug.BundleConfig{ IncludeSystemInfo: true, diff --git a/client/cmd/login.go b/client/cmd/login.go index 57c010571..64b45e557 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -7,7 +7,6 @@ import ( "os/user" "runtime" "strings" - "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo } func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { + authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if 
err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + needsLogin := false - err := WithBackOff(func() error { - err := internal.Login(ctx, config, "", "") - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - needsLogin = true - return nil - } - return err - }) - if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + err, isAuthError := authClient.Login(ctx, "", "") + if isAuthError { + needsLogin = true + } else if err != nil { + return fmt.Errorf("login check failed: %v", err) } jwtToken := "" @@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken = tokenInfo.GetTokenToUse() } - var lastError error - - err = WithBackOff(func() error { - err := internal.Login(ctx, config, setupKey, jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - lastError = err - return nil - } - return err - }) - - if lastError != nil { - return fmt.Errorf("login failed: %v", lastError) - } - + err, _ = authClient.Login(ctx, setupKey, jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return nil @@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) - defer c() - - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 8bbbef0f2..e266aae28 100644 --- a/client/embed/embed.go +++ 
b/client/embed/embed.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" @@ -176,7 +177,13 @@ func (c *Client) Start(startCtx context.Context) error { // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { + authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) + if err != nil { + return fmt.Errorf("create auth client: %w", err) + } + defer authClient.Close() + + if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2563a9052..716385705 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return fmt.Errorf("acl manager init: %w", err) } + if err := m.initNoTrackChain(); err != nil { + return fmt.Errorf("init notrack chain: %w", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { var merr *multierror.Error + if err := m.cleanupNoTrackChain(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot 
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +const ( + chainNameRaw = "NETBIRD-RAW" + chainOUTPUT = "OUTPUT" + tableRaw = "raw" +) + +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which +// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark). +// +// Traffic flows that need NOTRACK: +// +// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite) +// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort +// Matched by: sport=wgPort +// +// 2. Egress: Proxy -> WireGuard (via raw socket) +// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 3. Ingress: Packets to WireGuard +// dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 4. Ingress: Packets to proxy (after eBPF rewrite) +// dst=127.0.0.1:proxyPort +// Matched by: dport=proxyPort +// +// Rules are cleaned up when the firewall manager is closed. 
+func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + wgPortStr := fmt.Sprintf("%d", wgPort) + proxyPortStr := fmt.Sprintf("%d", proxyPort) + + // Egress rules: match outgoing loopback UDP packets + outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil { + return fmt.Errorf("add output sport notrack rule: %w", err) + } + + outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil { + return fmt.Errorf("add output dport notrack rule: %w", err) + } + + // Ingress rules: match incoming loopback UDP packets + preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil { + return fmt.Errorf("add prerouting wg notrack rule: %w", err) + } + + preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil { + return fmt.Errorf("add prerouting proxy notrack rule: %w", err) + } + + log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort) + return nil +} + +func (m *Manager) initNoTrackChain() error { + if err := m.cleanupNoTrackChain(); err != nil { + log.Debugf("cleanup notrack chain: %v", err) + } + + if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil { + return fmt.Errorf("create chain: %w", err) + } + + jumpRule := []string{"-j", chainNameRaw} + + if err := m.ipv4Client.InsertUnique(tableRaw, 
chainOUTPUT, 1, jumpRule...); err != nil { + if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil { + log.Debugf("delete orphan chain: %v", delErr) + } + return fmt.Errorf("add output jump rule: %w", err) + } + + if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil { + if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil { + log.Debugf("delete output jump rule: %v", delErr) + } + if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil { + log.Debugf("delete orphan chain: %v", delErr) + } + return fmt.Errorf("add prerouting jump rule: %w", err) + } + + return nil +} + +func (m *Manager) cleanupNoTrackChain() error { + exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw) + if err != nil { + return fmt.Errorf("check chain exists: %w", err) + } + if !exists { + return nil + } + + jumpRule := []string{"-j", chainNameRaw} + + if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil { + return fmt.Errorf("remove output jump rule: %w", err) + } + + if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil { + return fmt.Errorf("remove prerouting jump rule: %w", err) + } + + if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil { + return fmt.Errorf("clear and delete chain: %w", err) + } + + return nil +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 72e6a5c68..3511a5463 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -168,6 +168,10 @@ type Manager interface { // RemoveInboundDNAT removes inbound DNAT rule RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // SetupEBPFProxyNoTrack creates 
static notrack rules for eBPF proxy loopback traffic. + // This prevents conntrack from interfering with WireGuard proxy communication. + SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index bd19f1067..acf482f86 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -12,6 +12,7 @@ import ( "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -48,8 +49,10 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + notrackOutputChain *nftables.Chain + notrackPreroutingChain *nftables.Chain } // Create nftables firewall manager @@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return fmt.Errorf("acl manager init: %w", err) } + if err := m.initNoTrackChains(workTable); err != nil { + return fmt.Errorf("init notrack chains: %w", err) + } + stateManager.RegisterState(&ShutdownState{}) // We only need to record minimal interface state for potential recreation. 
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.Flush() + if err := m.aclManager.Flush(); err != nil { + return err + } + + if err := m.refreshNoTrackChains(); err != nil { + log.Errorf("failed to refresh notrack chains: %v", err) + } + + return nil } // AddDNATRule adds a DNAT rule @@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +const ( + chainNameRawOutput = "netbird-raw-out" + chainNameRawPrerouting = "netbird-raw-pre" +) + +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which +// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark). +// +// Traffic flows that need NOTRACK: +// +// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite) +// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort +// Matched by: sport=wgPort +// +// 2. Egress: Proxy -> WireGuard (via raw socket) +// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 3. Ingress: Packets to WireGuard +// dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 4. Ingress: Packets to proxy (after eBPF rewrite) +// dst=127.0.0.1:proxyPort +// Matched by: dport=proxyPort +// +// Rules are cleaned up when the firewall manager is closed. 
+func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil { + return fmt.Errorf("notrack chains not initialized") + } + + proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort) + wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort) + loopback := []byte{127, 0, 0, 1} + + // Egress rules: match outgoing loopback UDP packets + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackOutputChain.Table, + Chain: m.notrackOutputChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackOutputChain.Table, + Chain: m.notrackOutputChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: 
expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + + // Ingress rules: match incoming loopback UDP packets + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackPreroutingChain.Table, + Chain: m.notrackPreroutingChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackPreroutingChain.Table, + Chain: m.notrackPreroutingChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: 
expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush notrack rules: %w", err) + } + + log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort) + return nil +} + +func (m *Manager) initNoTrackChains(table *nftables.Table) error { + m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{ + Name: chainNameRawOutput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityRaw, + }) + + m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{ + Name: chainNameRawPrerouting, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityRaw, + }) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush chain creation: %w", err) + } + + return nil +} + +func (m *Manager) refreshNoTrackChains() error { + chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + + tableName := getTableName() + for _, c := range chains { + if c.Table.Name != tableName { + continue + } + switch c.Name { + case chainNameRawOutput: + m.notrackOutputChain = c + case chainNameRawPrerouting: + m.notrackPreroutingChain = c + } + } + + return nil +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 8caa1a0ad..aacc4ca1c 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -570,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Flush doesn't need to be implemented for this manager func (m *Manager) 
Flush() error { return nil } +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + if m.nativeFirewall == nil { + return nil + } + return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort) +} + // UpdateSet updates the rule destinations associated with the given set // by merging the existing prefixes with the new ones, then deduplicating. func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { diff --git a/client/iface/bind/dual_stack_conn.go b/client/iface/bind/dual_stack_conn.go new file mode 100644 index 000000000..061016ecc --- /dev/null +++ b/client/iface/bind/dual_stack_conn.go @@ -0,0 +1,169 @@ +package bind + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +var ( + errNoIPv4Conn = errors.New("no IPv4 connection available") + errNoIPv6Conn = errors.New("no IPv6 connection available") + errInvalidAddr = errors.New("invalid address type") +) + +// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes +// to the appropriate connection based on the destination address. +// ReadFrom is not used in the hot path - ICEBind receives packets via +// BatchReader.ReadBatch() directly. This is only used by udpMux for sending. +type DualStackPacketConn struct { + ipv4Conn net.PacketConn + ipv6Conn net.PacketConn + + readFromWarn sync.Once +} + +// NewDualStackPacketConn creates a new dual-stack packet connection. +func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn { + return &DualStackPacketConn{ + ipv4Conn: ipv4Conn, + ipv6Conn: ipv6Conn, + } +} + +// ReadFrom reads from the available connection (preferring IPv4). +// NOTE: This method is NOT used in the data path. 
ICEBind receives packets via +// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient. +// This implementation exists only to satisfy the net.PacketConn interface for the udpMux, +// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom() +// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path. +func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + d.readFromWarn.Do(func() { + log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path") + }) + + if d.ipv4Conn != nil { + return d.ipv4Conn.ReadFrom(b) + } + if d.ipv6Conn != nil { + return d.ipv6Conn.ReadFrom(b) + } + return 0, nil, net.ErrClosed +} + +// WriteTo writes to the appropriate connection based on the address type. +func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, &net.OpError{ + Op: "write", + Net: "udp", + Addr: addr, + Err: errInvalidAddr, + } + } + + if udpAddr.IP.To4() == nil { + if d.ipv6Conn != nil { + return d.ipv6Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp6", + Addr: addr, + Err: errNoIPv6Conn, + } + } + + if d.ipv4Conn != nil { + return d.ipv4Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp4", + Addr: addr, + Err: errNoIPv4Conn, + } +} + +// Close closes both connections. +func (d *DualStackPacketConn) Close() error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// LocalAddr returns the local address of the IPv4 connection if available, +// otherwise the IPv6 connection. 
+func (d *DualStackPacketConn) LocalAddr() net.Addr { + if d.ipv4Conn != nil { + return d.ipv4Conn.LocalAddr() + } + if d.ipv6Conn != nil { + return d.ipv6Conn.LocalAddr() + } + return nil +} + +// SetDeadline sets the deadline for both connections. +func (d *DualStackPacketConn) SetDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetReadDeadline sets the read deadline for both connections. +func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetWriteDeadline sets the write deadline for both connections. 
+func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} diff --git a/client/iface/bind/dual_stack_conn_bench_test.go b/client/iface/bind/dual_stack_conn_bench_test.go new file mode 100644 index 000000000..940c44966 --- /dev/null +++ b/client/iface/bind/dual_stack_conn_bench_test.go @@ -0,0 +1,119 @@ +package bind + +import ( + "net" + "testing" +) + +var ( + ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} + ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345} + payload = make([]byte, 1200) +) + +func BenchmarkWriteTo_DirectUDPConn(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = conn.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(conn, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) { + conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(nil, conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, 
Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + addrs := []net.Addr{ipv4Addr, ipv6Addr} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, addrs[i&1]) + } +} diff --git a/client/iface/bind/dual_stack_conn_test.go b/client/iface/bind/dual_stack_conn_test.go new file mode 100644 index 000000000..3007d907f --- /dev/null +++ b/client/iface/bind/dual_stack_conn_test.go @@ -0,0 +1,191 @@ +package bind + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) { + ipv4Conn := &mockPacketConn{network: "udp4"} + ipv6Conn := &mockPacketConn{network: "udp6"} + dualStack := 
NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + tests := []struct { + name string + addr *net.UDPAddr + wantSocket string + }{ + { + name: "IPv4 address", + addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 address", + addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + wantSocket: "udp6", + }, + { + name: "IPv4-mapped IPv6 goes to IPv4", + addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv4 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234}, + wantSocket: "udp6", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ipv4Conn.writeCount = 0 + ipv6Conn.writeCount = 0 + + n, err := dualStack.WriteTo([]byte("test"), tt.addr) + require.NoError(t, err) + assert.Equal(t, 4, n) + + if tt.wantSocket == "udp4" { + assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4") + assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6") + } else { + assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4") + assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6") + } + }) + } +} + +func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) { + dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil) + + // IPv4 works + _, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.NoError(t, err) + + // IPv6 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv6 connection") +} + +func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) { + dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"}) + + // IPv6 works + _, err := 
dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.NoError(t, err) + + // IPv4 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv4 connection") +} + +// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom +// only reads from one socket (IPv4 preferred). This is fine because the actual +// receive path uses wireguard-go's BatchReader directly, not ReadFrom. +func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) { + ipv4Conn := &mockPacketConn{ + network: "udp4", + readData: []byte("from ipv4"), + readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + } + ipv6Conn := &mockPacketConn{ + network: "udp6", + readData: []byte("from ipv6"), + readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + } + + dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + buf := make([]byte, 100) + n, addr, err := dualStack.ReadFrom(buf) + + require.NoError(t, err) + // reads from IPv4 (preferred) - this is expected behavior + assert.Equal(t, "from ipv4", string(buf[:n])) + assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String()) +} + +func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) { + ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820} + ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820} + + tests := []struct { + name string + ipv4 net.PacketConn + ipv6 net.PacketConn + wantAddr net.Addr + }{ + { + name: "both available returns IPv4", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv4Addr, + }, + { + name: "IPv4 only", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: nil, + wantAddr: ipv4Addr, + }, + { + name: "IPv6 only", + ipv4: nil, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv6Addr, + }, + { + name: "neither 
returns nil", + ipv4: nil, + ipv6: nil, + wantAddr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6) + assert.Equal(t, tt.wantAddr, dualStack.LocalAddr()) + }) + } +} + +// mock + +type mockPacketConn struct { + network string + writeCount int + readData []byte + readAddr net.Addr + localAddr net.Addr +} + +func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + if m.readData != nil { + return copy(b, m.readData), m.readAddr, nil + } + return 0, nil, nil +} + +func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + m.writeCount++ + return len(b), nil +} + +func (m *mockPacketConn) Close() error { return nil } +func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr } +func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 0957d2dd5..bf79ecd79 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -14,7 +14,6 @@ import ( "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" @@ -28,22 +27,7 @@ type receiverCreator struct { } func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { - if ipv4PC, ok := pc.(*ipv4.PacketConn); ok { - return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool) - } - // IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the - // wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6. 
- return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - buf := bufs[0] - size, ep, err := conn.ReadFromUDPAddrPort(buf) - if err != nil { - return 0, err - } - sizes[0] = size - stdEp := &wgConn.StdNetEndpoint{AddrPort: ep} - eps[0] = stdEp - return 1, nil - } + return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool) } // ICEBind is a bind implementation with two main features: @@ -73,6 +57,8 @@ type ICEBind struct { muUDPMux sync.Mutex udpMux *udpmux.UniversalUDPMuxDefault + ipv4Conn *net.UDPConn + ipv6Conn *net.UDPConn } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { @@ -118,6 +104,12 @@ func (s *ICEBind) Close() error { close(s.closedChan) + s.muUDPMux.Lock() + s.ipv4Conn = nil + s.ipv6Conn = nil + s.udpMux = nil + s.muUDPMux.Unlock() + return s.StdNetBind.Close() } @@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { +func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = udpmux.NewUniversalUDPMuxDefault( - udpmux.UniversalUDPMuxParams{ - UDPConn: nbnet.WrapPacketConn(conn), - Net: s.transportNet, - FilterFn: s.filterFn, - WGAddress: s.address, - MTU: s.mtu, - }, - ) + // Detect IPv4 vs IPv6 from connection's local address + if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil { + s.ipv4Conn = conn + } else { + s.ipv6Conn = conn + } + s.createOrUpdateMux() + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { msgs := getMessages(msgsPool) for i := range bufs { @@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r (*msgs)[i].OOB = 
(*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer putMessages(msgs, msgsPool) + var numMsgs int if runtime.GOOS == "linux" || runtime.GOOS == "android" { if rxOffload { readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams) - //nolint - numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0) + //nolint:staticcheck + _, err = pc.ReadBatch((*msgs)[readAt:], 0) if err != nil { return 0, err } @@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } numMsgs = 1 } + for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] // todo: handle err - ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) - if ok { + if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok { continue } sizes[i] = msg.N @@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } } +// createOrUpdateMux creates or updates the UDP mux with the available connections. +// Must be called with muUDPMux held. +func (s *ICEBind) createOrUpdateMux() { + var muxConn net.PacketConn + + switch { + case s.ipv4Conn != nil && s.ipv6Conn != nil: + muxConn = NewDualStackPacketConn( + nbnet.WrapPacketConn(s.ipv4Conn), + nbnet.WrapPacketConn(s.ipv6Conn), + ) + case s.ipv4Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv4Conn) + case s.ipv6Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv6Conn) + default: + return + } + + // Don't close the old mux - it doesn't own the underlying connections. + // The sockets are managed by WireGuard's StdNetBind, not by us. 
+ s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ + UDPConn: muxConn, + Net: s.transportNet, + FilterFn: s.filterFn, + WGAddress: s.address, + MTU: s.mtu, + }, + ) +} + func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { for i := range buffers { if !stun.IsMessage(buffers[i]) { @@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) return true, err } - muxErr := s.udpMux.HandleSTUNMessage(msg, addr) - if muxErr != nil { - log.Warnf("failed to handle STUN packet") + s.muUDPMux.Lock() + mux := s.udpMux + s.muUDPMux.Unlock() + + if mux != nil { + if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil { + log.Warnf("failed to handle STUN packet: %v", muxErr) + } } buffers[i] = []byte{} diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go new file mode 100644 index 000000000..1fdd955c9 --- /dev/null +++ b/client/iface/bind/ice_bind_test.go @@ -0,0 +1,324 @@ +package bind + +import ( + "fmt" + "net" + "net/netip" + "sync" + "testing" + "time" + + "github.com/pion/transport/v3/stdnet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, ipv6Conn := createDualStackConns(t) + defer ipv4Conn.Close() + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + pool := createMsgPool() + + // Simulate wireguard-go calling CreateReceiverFn for IPv4 + ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool) + require.NotNil(t, ipv4RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection") + assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet") + assert.NotNil(t, iceBind.udpMux, "mux should be created after 
first connection") + iceBind.muUDPMux.Unlock() + + // Simulate wireguard-go calling CreateReceiverFn for IPv6 + ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool) + require.NotNil(t, ipv6RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection") + assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection") + assert.NotNil(t, iceBind.udpMux, "mux should still exist") + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv4Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + defer ipv4Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn) + assert.Nil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv6Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.Nil(t, iceBind.ipv4Conn) + assert.NotNil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate +// with peers on different 
address families through the same DualStackPacketConn. +func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) { + // two "remote peers" listening on different address families + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + // our local dual-stack connection + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + // send to both peers + _, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr()) + require.NoError(t, err) + + _, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr()) + require.NoError(t, err) + + // verify IPv4 peer got its packet from the IPv4 socket + buf := make([]byte, 100) + _ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err := ipv4Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv4", string(buf[:n])) + assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) + + // verify IPv6 peer got its packet from the IPv6 socket + _ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err = ipv6Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv6", string(buf[:n])) + assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) +} + +// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4 +// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets, +// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP. 
+func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + const packetsPerFamily = 500 + + ipv4Received := make(chan string, packetsPerFamily) + ipv6Received := make(chan string, packetsPerFamily) + + startGate := make(chan struct{}) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv4Peer.ReadFrom(buf) + if err != nil { + return + } + ipv4Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv6Peer.ReadFrom(buf) + if err != nil { + return + } + ipv6Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr()) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr()) + } + }() + + close(startGate) + + time.AfterFunc(5*time.Second, func() { + _ = ipv4Peer.SetReadDeadline(time.Now()) + _ = ipv6Peer.SetReadDeadline(time.Now()) + }) + + wg.Wait() + close(ipv4Received) + close(ipv6Received) + + ipv4Count := 0 + for pkt := range ipv4Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt) + ipv4Count++ + } + + ipv6Count := 0 + for pkt := 
range ipv6Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt) + ipv6Count++ + } + + assert.Equal(t, packetsPerFamily, ipv4Count) + assert.Equal(t, packetsPerFamily, ipv6Count) +} + +func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { + tests := []struct { + name string + network string + addr string + wantIPv4 bool + }{ + {"IPv4 any", "udp4", "0.0.0.0:0", true}, + {"IPv4 loopback", "udp4", "127.0.0.1:0", true}, + {"IPv6 any", "udp6", "[::]:0", false}, + {"IPv6 loopback", "udp6", "[::1]:0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := net.ResolveUDPAddr(tt.network, tt.addr) + require.NoError(t, err) + + conn, err := net.ListenUDP(tt.network, addr) + if err != nil { + t.Skipf("%s not available: %v", tt.network, err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + isIPv4 := localAddr.IP.To4() != nil + assert.Equal(t, tt.wantIPv4, isIPv4) + }) + } +} + +// helpers + +func setupICEBind(t *testing.T) *ICEBind { + t.Helper() + transportNet, err := stdnet.NewNet() + require.NoError(t, err) + + address := wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/10"), + } + return NewICEBind(transportNet, nil, address, 1280) +} + +func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) { + t.Helper() + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + ipv4Conn.Close() + t.Skipf("IPv6 not available: %v", err) + } + return ipv4Conn, ipv6Conn +} + +func createMsgPool() *sync.Pool { + return &sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, 1) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, 40) + } + return &msgs + }, + } +} + +func listenUDP(t *testing.T, 
network, addr string) *net.UDPConn { + t.Helper() + udpAddr, err := net.ResolveUDPAddr(network, addr) + require.NoError(t, err) + conn, err := net.ListenUDP(network, udpAddr) + require.NoError(t, err) + return conn +} diff --git a/client/iface/iface.go b/client/iface/iface.go index 71fd433ad..e5623c979 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error { type wgProxyFactory interface { GetProxy() wgproxy.Proxy + GetProxyPort() uint16 Free() error } @@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetProxyPort returns the proxy port used by the WireGuard proxy. +// Returns 0 if no proxy port is used (e.g., for userspace WireGuard). +func (w *WGIface) GetProxyPort() uint16 { + return w.wgProxyFactory.GetProxyPort() +} + // GetBind returns the EndpointManager userspace bind mode. func (w *WGIface) GetBind() device.EndpointManager { w.mu.Lock() diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index eb585d8a2..9ac3ea6df 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() { } func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + ep, err := addrToEndpoint(endpoint) + if err != nil { + log.Errorf("failed to start package redirection: %v", err) + return + } + p.pausedCond.L.Lock() p.paused = false - p.wgCurrentUsed = addrToEndpoint(endpoint) + p.wgCurrentUsed = ep p.pausedCond.Signal() p.pausedCond.L.Unlock() } -func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { - ip, _ := netip.AddrFromSlice(addr.IP.To4()) - addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) - return &bind.Endpoint{AddrPort: addrPort} -} - func (p *ProxyBind) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") @@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { netipAddr := 
netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil } + +func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { + if addr == nil { + return nil, fmt.Errorf("invalid address") + } + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + return nil, fmt.Errorf("convert %s to netip.Addr", addr) + } + + addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort}, nil +} diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 858143091..5458519fa 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -27,12 +27,19 @@ const ( ) var ( - localHostNetIP = net.ParseIP("127.0.0.1") + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } ) // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int + proxyPort int mtu uint16 ebpfManager ebpfMgr.Manager @@ -40,7 +47,8 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex lastUsedPort uint16 - rawConn net.PacketConn + rawConnIPv4 net.PacketConn + rawConnIPv6 net.PacketConn conn transport.UDPConn ctx context.Context @@ -62,23 +70,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy { // Listen load ebpf program and listen the proxy func (p *WGEBPFProxy) Listen() error { pl := portLookup{} - wgPorxyPort, err := pl.searchFreePort() + proxyPort, err := pl.searchFreePort() + if err != nil { + return err + } + p.proxyPort = proxyPort + + // Prepare IPv4 raw socket (required) + p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4() if err != nil { return err } - p.rawConn, err = rawsocket.PrepareSenderRawSocket() + // Prepare IPv6 raw socket (optional) + p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6() if err != nil { - return err + log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 
only: %v", err) } - err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort) + err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort) if err != nil { + if closeErr := p.rawConnIPv4.Close(); closeErr != nil { + log.Warnf("failed to close IPv4 raw socket: %v", closeErr) + } + if p.rawConnIPv6 != nil { + if closeErr := p.rawConnIPv6.Close(); closeErr != nil { + log.Warnf("failed to close IPv6 raw socket: %v", closeErr) + } + } return err } addr := net.UDPAddr{ - Port: wgPorxyPort, + Port: proxyPort, IP: net.ParseIP(loopbackAddr), } @@ -94,7 +118,7 @@ func (p *WGEBPFProxy) Listen() error { p.conn = conn go p.proxyToRemote() - log.Infof("local wg proxy listening on: %d", wgPorxyPort) + log.Infof("local wg proxy listening on: %d", proxyPort) return nil } @@ -135,12 +159,25 @@ func (p *WGEBPFProxy) Free() error { result = multierror.Append(result, err) } - if err := p.rawConn.Close(); err != nil { - result = multierror.Append(result, err) + if p.rawConnIPv4 != nil { + if err := p.rawConnIPv4.Close(); err != nil { + result = multierror.Append(result, err) + } + } + + if p.rawConnIPv6 != nil { + if err := p.rawConnIPv6.Close(); err != nil { + result = multierror.Append(result, err) + } } return nberrors.FormatErrorOrNil(result) } +// GetProxyPort returns the proxy listening port. +func (p *WGEBPFProxy) GetProxyPort() uint16 { + return uint16(p.proxyPort) +} + // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. 
func (p *WGEBPFProxy) proxyToRemote() { @@ -218,31 +255,60 @@ generatePort: } func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - payload := gopacket.Payload(data) - ipH := &layers.IPv4{ - DstIP: localHostNetIP, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var dstIP net.IP + var rawConn net.PacketConn + + if endpointAddr.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpointAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + dstIP = localHostNetIPv4 + rawConn = p.rawConnIPv4 + } else { + // IPv6 path + if p.rawConnIPv6 == nil { + return fmt.Errorf("IPv6 raw socket not available") + } + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpointAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + dstIP = localHostNetIPv6 + rawConn = p.rawConnIPv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } - err := udpH.SetNetworkLayerForChecksum(ipH) - if err != nil { + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { return fmt.Errorf("set network layer for checksum: %w", err) } layerBuffer := gopacket.NewSerializeBuffer() + payload := gopacket.Payload(data) - err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) - if err != nil { + if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { + + if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { return fmt.Errorf("write to 
raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index ff44d30c0..5b98be7b4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -41,7 +41,7 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) @@ -91,6 +91,10 @@ func (p *ProxyWrapper) Pause() { } func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + if endpoint == nil || endpoint.IP == nil { + log.Errorf("failed to start package redirection, endpoint is nil") + return + } p.pausedCond.L.Lock() p.paused = false diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 2714c5774..7821df3de 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy { return ebpf.NewProxyWrapper(w.ebpfProxy) } +// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active. +func (w *KernelFactory) GetProxyPort() uint16 { + if w.ebpfProxy == nil { + return 0 + } + return w.ebpfProxy.GetProxyPort() +} + func (w *KernelFactory) Free() error { if w.ebpfProxy == nil { return nil diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index a1b1c34d7..bbd67e076 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy { return proxyBind.NewProxyBind(w.bind, w.mtu) } +// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port. 
+func (w *USPFactory) GetProxyPort() uint16 { + return 0 +} + func (w *USPFactory) Free() error { return nil } diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go index a11ac46d5..bc785b43a 100644 --- a/client/iface/wgproxy/rawsocket/rawsocket.go +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -8,43 +8,87 @@ import ( "os" "syscall" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + nbnet "github.com/netbirdio/netbird/client/net" ) -func PrepareSenderRawSocket() (net.PacketConn, error) { +// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets +func PrepareSenderRawSocketIPv4() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET, true) +} + +// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets +func PrepareSenderRawSocketIPv6() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET6, false) +} + +func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) { // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { return nil, fmt.Errorf("creating raw socket failed: %w", err) } - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + // Set the header include option on the socket to tell the kernel that headers are included in the packet. + // For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers. 
+ if isIPv4 { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + } else { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err) + } } // Bind the socket to the "lo" interface. err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("binding to lo interface failed: %w", err) } // Set the fwmark on the socket. err = nbnet.SetSocketOpt(fd) if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("setting fwmark failed: %w", err) } // Convert the file descriptor to a PacketConn. 
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("converting fd to file failed") } packetConn, err := net.FilePacketConn(file) if err != nil { + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file: %v", closeErr) + } return nil, fmt.Errorf("converting file to packet conn failed: %w", err) } + // Close the original file to release the FD (net.FilePacketConn duplicates it) + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file after creating packet conn: %v", closeErr) + } + return packetConn, nil } diff --git a/client/iface/wgproxy/redirect_test.go b/client/iface/wgproxy/redirect_test.go new file mode 100644 index 000000000..b52eead25 --- /dev/null +++ b/client/iface/wgproxy/redirect_test.go @@ -0,0 +1,353 @@ +//go:build linux && !android + +package wgproxy + +import ( + "context" + "net" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs +// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore +func compareUDPAddr(addr1, addr2 net.Addr) bool { + udpAddr1, ok1 := addr1.(*net.UDPAddr) + udpAddr2, ok2 := addr2.(*net.UDPAddr) + + if !ok1 || !ok2 { + return addr1.String() == addr2.String() + } + + // Compare IP and Port, ignoring zone + return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port +} + +// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses +func TestRedirectAs_eBPF_IPv4(t *testing.T) { + wgPort := 51850 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err 
!= nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses +func TestRedirectAs_eBPF_IPv6(t *testing.T) { + wgPort := 51851 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses +func TestRedirectAs_UDP_IPv4(t *testing.T) { + wgPort := 51852 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses +func TestRedirectAs_UDP_IPv6(t *testing.T) { + wgPort := 51853 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } 
+ + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// testRedirectAs is a helper function that tests the RedirectAs functionality +// It verifies that: +// 1. Initial traffic from relay connection works +// 2. After calling RedirectAs, packets appear to come from the p2p endpoint +// 3. Multiple packets are correctly redirected with the new source address +func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) { + t.Helper() + + ctx := context.Background() + + // Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types + // In reality, WireGuard binds to a port and receives from both IPv4 and IPv6 + wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv4 WireGuard listener: %v", err) + } + defer wgListener4.Close() + + wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{ + IP: net.ParseIP("::1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv6 WireGuard listener: %v", err) + } + defer wgListener6.Close() + + // Determine which listener to use based on the NetBird address IP version + // (this is where initial traffic will come from before RedirectAs is called) + var wgListener *net.UDPConn + if p2pEndpoint.IP.To4() == nil { + wgListener = wgListener6 + } else { + wgListener = wgListener4 + } + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, // Random port + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + // Add TURN connection to proxy + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add 
TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + // Start the proxy + proxy.Work() + + // Phase 1: Test initial relay traffic + msgFromRelay := []byte("hello from relay") + if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write to relay server: %v", err) + } + + // Set read deadline to avoid hanging + if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + buf := make([]byte, 1024) + n, _, err := wgListener4.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read from WireGuard listener: %v", err) + } + + if n != len(msgFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n) + } + + if string(buf[:n]) != string(msgFromRelay) { + t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n]) + } + + // Phase 2: Redirect to p2p endpoint + proxy.RedirectAs(p2pEndpoint) + + // Give the proxy a moment to process the redirect + time.Sleep(100 * time.Millisecond) + + // Phase 3: Test redirected traffic + redirectedMessages := [][]byte{ + []byte("redirected message 1"), + []byte("redirected message 2"), + []byte("redirected message 3"), + } + + for i, msg := range redirectedMessages { + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write redirected message %d: %v", i+1, err) + } + + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read redirected message %d: %v", i+1, err) + } + + // Verify message content + if string(buf[:n]) != string(msg) { + t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n]) + } + + // Verify source address matches p2p endpoint (this is the 
key test) + // Use compareUDPAddr to ignore IPv6 zone IDs + if !compareUDPAddr(srcAddr, p2pEndpoint) { + t.Errorf("message %d: expected source address %s, got %s", + i+1, p2pEndpoint.String(), srcAddr.String()) + } + } +} + +// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints +func TestRedirectAs_Multiple_Switches(t *testing.T) { + wgPort := 51856 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + ctx := context.Background() + + // Create WireGuard listener + wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create WireGuard listener: %v", err) + } + defer wgListener.Close() + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + proxy.Work() + + // Test switching between multiple endpoints - using addresses in local subnet + endpoints := []*net.UDPAddr{ + {IP: net.ParseIP("192.168.0.100"), Port: 51820}, + {IP: net.ParseIP("192.168.0.101"), Port: 51821}, + {IP: 
net.ParseIP("192.168.0.102"), Port: 51822}, + } + + for i, endpoint := range endpoints { + proxy.RedirectAs(endpoint) + time.Sleep(100 * time.Millisecond) + + msg := []byte("test message") + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write message for endpoint %d: %v", i, err) + } + + buf := make([]byte, 1024) + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read message for endpoint %d: %v", i, err) + } + + if string(buf[:n]) != string(msg) { + t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n]) + } + + if !compareUDPAddr(srcAddr, endpoint) { + t.Errorf("endpoint %d: expected source %s, got %s", + i, endpoint.String(), srcAddr.String()) + } + } +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 4ef2f19c4..6069d1960 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. 
-func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go index fdc911463..cc099d9df 100644 --- a/client/iface/wgproxy/udp/rawsocket.go +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -19,37 +19,56 @@ var ( FixLengths: true, } - localHostNetIPAddr = &net.IPAddr{ + localHostNetIPAddrV4 = &net.IPAddr{ IP: net.ParseIP("127.0.0.1"), } + localHostNetIPAddrV6 = &net.IPAddr{ + IP: net.ParseIP("::1"), + } ) type SrcFaker struct { srcAddr *net.UDPAddr - rawSocket net.PacketConn - ipH gopacket.SerializableLayer - udpH gopacket.SerializableLayer - layerBuffer gopacket.SerializeBuffer + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer + localHostAddr *net.IPAddr } func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { - rawSocket, err := rawsocket.PrepareSenderRawSocket() + // Create only the raw socket for the address family we need + var rawSocket net.PacketConn + var err error + var localHostAddr *net.IPAddr + + if srcAddr.IP.To4() != nil { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4() + localHostAddr = localHostNetIPAddrV4 + } else { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6() + localHostAddr = localHostNetIPAddrV6 + } if err != nil { return nil, err } ipH, udpH, err := prepareHeaders(dstPort, srcAddr) if err != nil { + if closeErr := rawSocket.Close(); closeErr != nil { + log.Warnf("failed to close raw socket: %v", closeErr) + } return nil, err } f := &SrcFaker{ - srcAddr: srcAddr, - rawSocket: rawSocket, - ipH: ipH, - udpH: udpH, - layerBuffer: gopacket.NewSerializeBuffer(), + srcAddr: 
srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, } return f, nil @@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { if err != nil { return 0, fmt.Errorf("serialize layers: %w", err) } - n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr) if err != nil { return 0, fmt.Errorf("write to raw conn: %w", err) } @@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { } func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { - ipH := &layers.IPv4{ - DstIP: net.ParseIP("127.0.0.1"), - SrcIP: srcAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + + // Check if source IP is IPv4 or IPv6 + if srcAddr.IP.To4() != nil { + // IPv4 + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPAddrV4.IP, + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + } else { + // IPv6 + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPAddrV6.IP, + SrcIP: srcAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(srcAddr.Port), DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port } - err := udpH.SetNetworkLayerForChecksum(ipH) + err := udpH.SetNetworkLayerForChecksum(networkLayer) if err != nil { return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) } diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go new file mode 100644 index 000000000..44e98bede --- /dev/null +++ b/client/internal/auth/auth.go @@ -0,0 +1,499 @@ +package auth + +import ( + "context" + "net/url" + "sync" + "time" + + 
"github.com/cenkalti/backoff/v4" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/shared/management/client" + "github.com/netbirdio/netbird/shared/management/client/common" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +// Auth manages authentication operations with the management server +// It maintains a long-lived connection and automatically handles reconnection with backoff +type Auth struct { + mutex sync.RWMutex + client *mgm.GrpcClient + config *profilemanager.Config + privateKey wgtypes.Key + mgmURL *url.URL + mgmTLSEnabled bool +} + +// NewAuth creates a new Auth instance that manages authentication flows +// It establishes a connection to the management server that will be reused for all operations +// The connection is automatically recreated with backoff if it becomes disconnected +func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) { + // Validate WireGuard private key + myPrivateKey, err := wgtypes.ParseKey(privateKey) + if err != nil { + return nil, err + } + + // Determine TLS setting based on URL scheme + mgmTLSEnabled := mgmURL.Scheme == "https" + + log.Debugf("connecting to Management Service %s", mgmURL.String()) + mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) + if err != nil { + log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err) + return nil, err + } + + log.Debugf("connected to the Management service %s", mgmURL.String()) + + return &Auth{ + client: mgmClient, + config: config, + privateKey: myPrivateKey, + mgmURL: mgmURL, + mgmTLSEnabled: mgmTLSEnabled, + }, nil +} + +// Close 
closes the management client connection +func (a *Auth) Close() error { + a.mutex.Lock() + defer a.mutex.Unlock() + + if a.client == nil { + return nil + } + return a.client.Close() +} + +// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations. +// Returns true if either PKCE or Device authorization flow is supported, false otherwise. +// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers. +// Automatically retries with backoff and reconnection on connection errors. +func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) { + var supportsSSO bool + + err := a.withRetry(ctx, func(client *mgm.GrpcClient) error { + // Try PKCE flow first + _, err := a.getPKCEFlow(client) + if err == nil { + supportsSSO = true + return nil + } + + // Check if PKCE is not supported + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + // PKCE not supported, try Device flow + _, err = a.getDeviceFlow(client) + if err == nil { + supportsSSO = true + return nil + } + + // Check if Device flow is also not supported + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + // Neither PKCE nor Device flow is supported + supportsSSO = false + return nil + } + + // Device flow check returned an error other than NotFound/Unimplemented + return err + } + + // PKCE flow check returned an error other than NotFound/Unimplemented + return err + }) + + return supportsSSO, err +} + +// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection +// This avoids creating a new connection to the management server +func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) { + var flow OAuthFlow + var err error + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + if forceDeviceAuth { + flow, err = 
a.getDeviceFlow(client) + return err + } + + // Try PKCE flow first + flow, err = a.getPKCEFlow(client) + if err != nil { + // If PKCE not supported, try Device flow + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + flow, err = a.getDeviceFlow(client) + return err + } + return err + } + return nil + }) + + return flow, err +} + +// IsLoginRequired checks if login is required by attempting to authenticate with the server +// Automatically retries with backoff and reconnection on connection errors. +func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return false, err + } + + var needsLogin bool + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + _, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + if isLoginNeeded(err) { + needsLogin = true + return nil + } + needsLogin = false + return err + }) + + return needsLogin, err +} + +// Login attempts to log in or register the client with the management server +// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries. +// Automatically retries with backoff and reconnection on connection errors. 
+func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return err, false + } + + var isAuthError bool + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + if serverKey != nil && isRegistrationNeeded(err) { + log.Debugf("peer registration required") + _, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey) + if err != nil { + isAuthError = isPermissionDenied(err) + return err + } + } else if err != nil { + isAuthError = isPermissionDenied(err) + return err + } + + isAuthError = false + return nil + }) + + return err, isAuthError +} + +// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance +func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey) + if err != nil { + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + log.Warnf("server couldn't find pkce flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve pkce flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &PKCEAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + TokenEndpoint: protoConfig.GetTokenEndpoint(), + AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(), + Scope: protoConfig.GetScope(), + RedirectURLs: protoConfig.GetRedirectURLs(), + UseIDToken: protoConfig.GetUseIDToken(), + ClientCertPair: a.config.ClientCertKeyPair, + DisablePromptLogin: protoConfig.GetDisablePromptLogin(), + LoginFlag: 
common.LoginFlag(protoConfig.GetLoginFlag()), + } + + if err := validatePKCEConfig(config); err != nil { + return nil, err + } + + flow, err := NewPKCEAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance +func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey) + if err != nil { + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + log.Warnf("server couldn't find device flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve device flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &DeviceAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + Domain: protoConfig.Domain, + TokenEndpoint: protoConfig.GetTokenEndpoint(), + DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(), + Scope: protoConfig.GetScope(), + UseIDToken: protoConfig.GetUseIDToken(), + } + + // Keep compatibility with older management versions + if config.Scope == "" { + config.Scope = "openid" + } + + if err := validateDeviceAuthConfig(config); err != nil { + return nil, err + } + + flow, err := NewDeviceAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// doMgmLogin performs the actual login operation with the management service +func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management 
Service public key: %v", err) + return nil, nil, err + } + + sysInfo := system.GetInfo(ctx) + a.setSystemInfoFlags(sysInfo) + loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels) + return serverKey, loginResp, err +} + +// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. +// Otherwise tries to register with the provided setupKey via command line. +func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { + serverPublicKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + validSetupKey, err := uuid.Parse(setupKey) + if err != nil && jwtToken == "" { + return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) + } + + log.Debugf("sending peer registration request to Management Service") + info := system.GetInfo(ctx) + a.setSystemInfoFlags(info) + loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) + if err != nil { + log.Errorf("failed registering peer %v", err) + return nil, err + } + + log.Infof("peer has been successfully registered on Management Service") + + return loginResp, nil +} + +// setSystemInfoFlags sets all configuration flags on the provided system info +func (a *Auth) setSystemInfoFlags(info *system.Info) { + info.SetFlags( + a.config.RosenpassEnabled, + a.config.RosenpassPermissive, + a.config.ServerSSHAllowed, + a.config.DisableClientRoutes, + a.config.DisableServerRoutes, + a.config.DisableDNS, + a.config.DisableFirewall, + a.config.BlockLANAccess, + a.config.BlockInbound, + a.config.LazyConnectionEnabled, + a.config.EnableSSHRoot, + a.config.EnableSSHSFTP, + a.config.EnableSSHLocalPortForwarding, + 
a.config.EnableSSHRemotePortForwarding, + a.config.DisableSSHAuth, + ) +} + +// reconnect closes the current connection and creates a new one +// It checks if the brokenClient is still the current client before reconnecting +// to avoid multiple threads reconnecting unnecessarily +func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error { + a.mutex.Lock() + defer a.mutex.Unlock() + + // Double-check: if client has already been replaced by another thread, skip reconnection + if a.client != brokenClient { + log.Debugf("client already reconnected by another thread, skipping") + return nil + } + + // Create new connection FIRST, before closing the old one + // This ensures a.client is never nil, preventing panics in other threads + log.Debugf("reconnecting to Management Service %s", a.mgmURL.String()) + mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled) + if err != nil { + log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err) + // Keep the old client if reconnection fails + return err + } + + // Close old connection AFTER new one is successfully created + oldClient := a.client + a.client = mgmClient + + if oldClient != nil { + if err := oldClient.Close(); err != nil { + log.Debugf("error closing old connection: %v", err) + } + } + + log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String()) + return nil +} + +// isConnectionError checks if the error is a connection-related error that should trigger reconnection +func isConnectionError(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + // These error codes indicate connection issues + return s.Code() == codes.Unavailable || + s.Code() == codes.DeadlineExceeded || + s.Code() == codes.Canceled || + s.Code() == codes.Internal +} + +// withRetry wraps an operation with exponential backoff retry logic +// It automatically reconnects on 
connection errors +func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error { + backoffSettings := &backoff.ExponentialBackOff{ + InitialInterval: 500 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 1.5, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 2 * time.Minute, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + } + backoffSettings.Reset() + + return backoff.RetryNotify( + func() error { + // Capture the client BEFORE the operation to ensure we track the correct client + a.mutex.RLock() + currentClient := a.client + a.mutex.RUnlock() + + if currentClient == nil { + return status.Errorf(codes.Unavailable, "client is not initialized") + } + + // Execute operation with the captured client + err := operation(currentClient) + if err == nil { + return nil + } + + // If it's a connection error, attempt reconnection using the client that was actually used + if isConnectionError(err) { + log.Warnf("connection error detected, attempting reconnection: %v", err) + + if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil { + log.Errorf("reconnection failed: %v", reconnectErr) + return reconnectErr + } + // Return the original error to trigger retry with the new connection + return err + } + + // For authentication errors, don't retry + if isAuthenticationError(err) { + return backoff.Permanent(err) + } + + return err + }, + backoff.WithContext(backoffSettings, ctx), + func(err error, duration time.Duration) { + log.Warnf("operation failed, retrying in %v: %v", duration, err) + }, + ) +} + +// isAuthenticationError checks if the error is an authentication-related error that should not be retried. +// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help. 
+func isAuthenticationError(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied +} + +// isPermissionDenied checks if the error is a PermissionDenied error. +// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access). +func isPermissionDenied(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + return s.Code() == codes.PermissionDenied +} + +func isLoginNeeded(err error) bool { + return isAuthenticationError(err) +} + +func isRegistrationNeeded(err error) bool { + return isPermissionDenied(err) +} diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index 8ca760742..e33765300 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/util/embeddedroots" ) @@ -26,12 +25,56 @@ const ( var _ OAuthFlow = &DeviceAuthorizationFlow{} +// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow +type DeviceAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Domain An IDP API domain + // Deprecated. 
Use OIDCConfigEndpoint instead + Domain string + // Audience An Audience for authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code + DeviceAuthEndpoint string + // Scope provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validateDeviceAuthConfig validates device authorization provider configuration +func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.Audience == "" { + return fmt.Errorf(errorMsgFormat, "Audience") + } + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.DeviceAuthEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Scopes") + } + return nil +} + // DeviceAuthorizationFlow implements the OAuthFlow interface, // for the Device Authorization Flow.
type DeviceAuthorizationFlow struct { - providerConfig internal.DeviceAuthProviderConfig - - HTTPClient HTTPClient + providerConfig DeviceAuthProviderConfig + HTTPClient HTTPClient } // RequestDeviceCodePayload used for request device code payload for auth0 @@ -57,7 +100,7 @@ type TokenRequestResponse struct { } // NewDeviceAuthorizationFlow returns device authorization flow client -func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { +func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string { return d.providerConfig.ClientID } +// SetLoginHint sets the login hint for the device authorization flow +func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) { + d.providerConfig.LoginHint = hint +} + // RequestAuthInfo requests a device code login flow information from Hosted func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { form := url.Values{} @@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR } // WaitToken waits user's login and authorize the app. Once the user's authorize -// it retrieves the access token from Hosted's endpoint and validates it before returning +// it retrieves the access token from Hosted's endpoint and validates it before returning. +// The method creates a timeout context internally based on info.ExpiresIn. 
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + interval := time.Duration(info.Interval) * time.Second ticker := time.NewTicker(interval) + defer ticker.Stop() + for { select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case <-ticker.C: tokenResponse, err := d.requestToken(info) diff --git a/client/internal/auth/device_flow_test.go b/client/internal/auth/device_flow_test.go index 466645ee9..6a433cb61 100644 --- a/client/internal/auth/device_flow_test.go +++ b/client/internal/auth/device_flow_test.go @@ -12,8 +12,6 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) type mockHTTPClient struct { @@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) { err: testCase.inputReqError, } - deviceFlow := &DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: expectedAudience, - ClientID: expectedClientID, - Scope: expectedScope, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: expectedAudience, + ClientID: expectedClientID, + Scope: expectedScope, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + authInfo, err := deviceFlow.RequestAuthInfo(context.TODO()) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) @@ -280,18 +279,19 @@ func TestHosted_WaitToken(t 
*testing.T) { countResBody: testCase.inputCountResBody, } - deviceFlow := DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: testCase.inputAudience, - ClientID: clientID, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - Scope: "openid", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: testCase.inputAudience, + ClientID: clientID, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + Scope: "openid", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 85a166005..a50a2ce6f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" ) @@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + pkceFlowInfo, err := 
authClient.getPKCEFlow(authClient.client) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } - pkceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + pkceFlowInfo.SetLoginHint(hint) + } - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + return pkceFlowInfo, nil } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client) if err != nil { switch s, ok := gstatus.FromError(err); { case ok && s.Code() == codes.NotFound: @@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. 
} } - deviceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + deviceFlowInfo.SetLoginHint(hint) + } - return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) + return deviceFlowInfo, nil } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index cc43c8648..2e16836d8 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -20,7 +20,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" "github.com/netbirdio/netbird/shared/management/client/common" ) @@ -35,17 +34,67 @@ const ( defaultPKCETimeoutSeconds = 300 ) +// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow +type PKCEAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Audience An Audience for authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code + AuthorizationEndpoint string + // Scope provides the scopes to be included in the token request + Scope string + // RedirectURLs handles authorization code from IDP manager + RedirectURLs []string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // ClientCertPair is used for mTLS authentication to the IDP + ClientCertPair *tls.Certificate + // DisablePromptLogin makes the PKCE flow to not prompt the user for login + DisablePromptLogin bool + // LoginFlag is used to configure the PKCE flow login behavior + LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validatePKCEConfig
validates PKCE provider configuration +func validatePKCEConfig(config *PKCEAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.AuthorizationEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes") + } + if config.RedirectURLs == nil { + return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs") + } + return nil +} + // PKCEAuthorizationFlow implements the OAuthFlow interface for // the Authorization Code Flow with PKCE. type PKCEAuthorizationFlow struct { - providerConfig internal.PKCEAuthProviderConfig + providerConfig PKCEAuthProviderConfig state string codeVerifier string oAuthConfig *oauth2.Config } // NewPKCEAuthorizationFlow returns new PKCE authorization code flow. -func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { +func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string excludedRanges := getSystemExcludedPortRanges() @@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn }, nil } +// SetLoginHint sets the login hint for the PKCE authorization flow +func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) { + p.providerConfig.LoginHint = hint +} + // WaitToken waits for the OAuth token in the PKCE Authorization Flow. // It starts an HTTP server to receive the OAuth token callback and waits for the token or an error. // Once the token is received, it is converted to TokenInfo and validated before returning. 
-func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) { +// The method creates a timeout context internally based on info.ExpiresIn. +func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) @@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} defer func() { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { @@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( go p.startServer(server, tokenChan, errChan) select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case token := <-tokenChan: return p.parseOAuthToken(token) case err := <-errChan: diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b77a17eaa..c487c13df 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/client/internal" mgm "github.com/netbirdio/netbird/shared/management/client/common" ) @@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - config := internal.PKCEAuthProviderConfig{ + config := PKCEAuthProviderConfig{ ClientID: "test-client-id", Audience: "test-audience", TokenEndpoint: 
"https://test-token-endpoint.com/token", diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go index dd455b2fe..125eb270a 100644 --- a/client/internal/auth/pkce_flow_windows_test.go +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) func TestParseExcludedPortRanges(t *testing.T) { @@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { availablePort := 65432 - config := internal.PKCEAuthProviderConfig{ + config := PKCEAuthProviderConfig{ ClientID: "test-client-id", Audience: "test-audience", TokenEndpoint: "https://test-token-endpoint.com/token", diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index d7a24fa38..b2208e68f 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -59,6 +59,7 @@ block.prof: Block profiling information. heap.prof: Heap profiling information (snapshot of memory allocations). allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. +cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. 
@@ -231,6 +232,8 @@ type BundleGenerator struct { statusRecorder *peer.Status syncResponse *mgmProto.SyncResponse logPath string + cpuProfile []byte + refreshStatus func() // Optional callback to refresh status before bundle generation clientMetrics MetricsExporter anonymize bool @@ -251,6 +254,8 @@ type GeneratorDependencies struct { StatusRecorder *peer.Status SyncResponse *mgmProto.SyncResponse LogPath string + CPUProfile []byte + RefreshStatus func() // Optional callback to refresh status before bundle generation ClientMetrics MetricsExporter } @@ -268,6 +273,8 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen statusRecorder: deps.StatusRecorder, syncResponse: deps.SyncResponse, logPath: deps.LogPath, + cpuProfile: deps.CPUProfile, + refreshStatus: deps.RefreshStatus, clientMetrics: deps.ClientMetrics, anonymize: cfg.Anonymize, @@ -332,6 +339,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add profiles to debug bundle: %v", err) } + if err := g.addCPUProfile(); err != nil { + log.Errorf("failed to add CPU profile to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -412,6 +423,10 @@ func (g *BundleGenerator) addStatus() error { profName = activeProf.Name } + if g.refreshStatus != nil { + g.refreshStatus() + } + fullStatus := g.statusRecorder.GetFullStatus() protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus) protoFullStatus.Events = g.statusRecorder.GetEventHistory() @@ -554,6 +569,19 @@ func (g *BundleGenerator) addProf() (err error) { return nil } +func (g *BundleGenerator) addCPUProfile() error { + if len(g.cpuProfile) == 0 { + return nil + } + + reader := bytes.NewReader(g.cpuProfile) + if err := g.addFileToZip(reader, "cpu.prof"); err != nil { + return fmt.Errorf("add CPU profile to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 
5242880) // 5 MB buffer n := runtime.Stack(buf, true) diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go deleted file mode 100644 index 7f7d06130..000000000 --- a/client/internal/device_auth.go +++ /dev/null @@ -1,136 +0,0 @@ -package internal - -import ( - "context" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" -) - -// DeviceAuthorizationFlow represents Device Authorization Flow information -type DeviceAuthorizationFlow struct { - Provider string - ProviderConfig DeviceAuthProviderConfig -} - -// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow -type DeviceAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Domain An IDP API domain - // Deprecated. 
Use OIDCConfigEndpoint instead - Domain string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code - DeviceAuthEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it -func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return DeviceAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return DeviceAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return DeviceAuthorizationFlow{}, err - } - - protoDeviceAuthorizationFlow, err := 
mgmClient.GetDeviceAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find device flow, contact admin: %v", err) - return DeviceAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve device flow: %v", err) - return DeviceAuthorizationFlow{}, err - } - - deviceAuthorizationFlow := DeviceAuthorizationFlow{ - Provider: protoDeviceAuthorizationFlow.Provider.String(), - - ProviderConfig: DeviceAuthProviderConfig{ - Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(), - Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain, - TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(), - Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(), - UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - }, - } - - // keep compatibility with older management versions - if deviceAuthorizationFlow.ProviderConfig.Scope == "" { - deviceAuthorizationFlow.ProviderConfig.Scope = "openid" - } - - err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) - if err != nil { - return DeviceAuthorizationFlow{}, err - } - - return deviceAuthorizationFlow, nil -} - -func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. 
Contact your NetBird administrator" - if config.Audience == "" { - return fmt.Errorf(errorMSGFormat, "Audience") - } - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.DeviceAuthEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Scopes") - } - return nil -} diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index cbdc64997..b374bcc6a 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - logger := log.WithField("request_id", resutil.GetRequestID(w)) + logger := log.WithFields(log.Fields{ + "request_id": resutil.GetRequestID(w), + "dns_id": fmt.Sprintf("%04x", r.Id), + }) if len(r.Question) == 0 { logger.Debug("received local resolver request with no question") diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 654d280ef..0fbd32771 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -71,6 +71,11 @@ type upstreamResolverBase struct { statusRecorder *peer.Status } +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) @@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - logger := log.WithField("request_id", resutil.GetRequestID(w)) + logger := log.WithFields(log.Fields{ + "request_id": resutil.GetRequestID(w), + "dns_id": fmt.Sprintf("%04x", 
r.Id), + }) u.prepareRequest(r) @@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - if u.tryUpstreamServers(w, r, logger) { - return + ok, failures := u.tryUpstreamServers(w, r, logger) + if len(failures) > 0 { + u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger) + } + if !ok { + u.writeErrorResponse(w, r, logger) } - - u.writeErrorResponse(w, r, logger) } func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { @@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } } -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { timeout := u.upstreamTimeout if len(u.upstreamServers) > 1 { maxTotal := 5 * time.Second @@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M } } + var failures []upstreamFailure for _, upstream := range u.upstreamServers { - if u.queryUpstream(w, r, upstream, timeout, logger) { - return true + if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil { + failures = append(failures, *failure) + } else { + return true, failures } } - return false + return false, failures } -func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { +// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. 
+func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { var rm *dns.Msg var t time.Duration var err error @@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u }() if err != nil { - u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) - return false + return u.handleUpstreamError(err, upstream, startTime) } if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - return false + return &upstreamFailure{upstream: upstream, reason: "no response"} } - return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { + return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + } + + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + return nil } -func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) - return + return &upstreamFailure{upstream: upstream, reason: err.Error()} } elapsed := time.Since(startTime) - timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond)) if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { - timeoutMsg += " " + peerInfo + reason += " " + peerInfo } - timeoutMsg += 
fmt.Sprintf(" - error: %v", err) - logger.Warn(timeoutMsg) + return &upstreamFailure{upstream: upstream, reason: reason} } func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { @@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn return true } -func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { - logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) +func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { + totalUpstreams := len(u.upstreamServers) + failedCount := len(failures) + failureSummary := formatFailures(failures) + if succeeded { + logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary) + } else { + logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary) + } +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { m := new(dns.Msg) m.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(m); err != nil { - logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err) } } +func formatFailures(failures []upstreamFailure) string { + parts := make([]string, 0, len(failures)) + for _, f := range failures { + parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason)) + } + return strings.Join(parts, ", ") +} + // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet 
*netstack.Net, r *dns.Msg, upst return reply, nil } - // FormatPeerStatus formats peer connection status information for debugging DNS timeouts func FormatPeerStatus(peerState *peer.State) string { isConnected := peerState.ConnStatus == peer.StatusConnected diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 2852f4775..8b06e4475 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,6 +2,7 @@ package dns import ( "context" + "fmt" "net" "net/netip" "strings" @@ -9,6 +10,8 @@ import ( "time" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/device" @@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) return c.r, c.rtt, c.err } +type mockUpstreamResponse struct { + msg *dns.Msg + err error +} + +type mockUpstreamResolverPerServer struct { + responses map[string]mockUpstreamResponse + rtt time.Duration +} + +func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { + if r, ok := c.responses[upstream]; ok { + return r.msg, c.rtt, r.err + } + return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) +} + func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { mockClient := &mockUpstreamResolver{ err: dns.ErrTime, @@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { t.Errorf("should be enabled") } } + +func TestUpstreamResolver_Failover(t *testing.T) { + upstream1 := netip.MustParseAddrPort("192.0.2.1:53") + upstream2 := netip.MustParseAddrPort("192.0.2.2:53") + + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + testCases := []struct { + name string + upstream1 mockUpstreamResponse + upstream2 mockUpstreamResponse + 
expectedRcode int + expectAnswer bool + expectTrySecond bool + }{ + { + name: "success on first upstream", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: false, + }, + { + name: "SERVFAIL from first should try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "REFUSED from first should try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "NXDOMAIN from first should NOT try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeNameError, + expectAnswer: false, + expectTrySecond: false, + }, + { + name: "timeout from first should try second", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "no response from first should try second", + upstream1: mockUpstreamResponse{msg: nil}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "both upstreams return SERVFAIL", + upstream1: mockUpstreamResponse{msg: 
buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "both upstreams timeout", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{err: timeoutErr}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first SERVFAIL then timeout", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{err: timeoutErr}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first timeout then SERVFAIL", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first REFUSED then SERVFAIL", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var queriedUpstreams []string + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream1.String(): tc.upstream1, + upstream2.String(): tc.upstream2, + }, + rtt: time.Millisecond, + } + + trackingClient := &trackingMockClient{ + inner: mockClient, + queriedUpstreams: &queriedUpstreams, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: trackingClient, + upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamTimeout: UpstreamTimeout, + } 
+ + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(responseWriter, inputMSG) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode") + + if tc.expectAnswer { + require.NotEmpty(t, responseMSG.Answer, "expected answer records") + assert.Contains(t, responseMSG.Answer[0].String(), successAnswer) + } + + if tc.expectTrySecond { + assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams") + assert.Equal(t, upstream1.String(), queriedUpstreams[0]) + assert.Equal(t, upstream2.String(), queriedUpstreams[1]) + } else { + assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream") + assert.Equal(t, upstream1.String(), queriedUpstreams[0]) + } + }) + } +} + +type trackingMockClient struct { + inner *mockUpstreamResolverPerServer + queriedUpstreams *[]string +} + +func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { + *t.queriedUpstreams = append(*t.queriedUpstreams, upstream) + return t.inner.exchange(ctx, upstream, r) +} + +func buildMockResponse(rcode int, answer string) *dns.Msg { + m := new(dns.Msg) + m.Response = true + m.Rcode = rcode + + if rcode == dns.RcodeSuccess && answer != "" { + m.Answer = []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP(answer), + }, + } + } + return m +} + +func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { + upstream := netip.MustParseAddrPort("192.0.2.1:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, 
cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamServers: []netip.AddrPort{upstream}, + upstreamTimeout: UpstreamTimeout, + } + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(responseWriter, inputMSG) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") +} + +func TestFormatFailures(t *testing.T) { + testCases := []struct { + name string + failures []upstreamFailure + expected string + }{ + { + name: "empty slice", + failures: []upstreamFailure{}, + expected: "", + }, + { + name: "single failure", + failures: []upstreamFailure{ + {upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"}, + }, + expected: "8.8.8.8:53=SERVFAIL", + }, + { + name: "multiple failures", + failures: []upstreamFailure{ + {upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"}, + {upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"}, + }, + expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := formatFailures(tc.failures) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/client/internal/engine.go b/client/internal/engine.go index d3aeb0fa6..ab3d9984c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -516,6 +516,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } + // Set up notrack rules immediately after proxy is listening to prevent + // conntrack entries from being created before the rules are in place + 
e.setupWGProxyNoTrack() + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -628,6 +632,23 @@ func (e *Engine) initFirewall() error { return nil } +// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic. +// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy. +func (e *Engine) setupWGProxyNoTrack() { + if e.firewall == nil { + return + } + + proxyPort := e.wgInterface.GetProxyPort() + if proxyPort == 0 { + return + } + + if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil { + log.Warnf("failed to setup ebpf proxy notrack: %v", err) + } +} + func (e *Engine) blockLanAccess() { if e.config.BlockInbound { // no need to set up extra deny rules if inbound is already blocked in general @@ -1062,6 +1083,9 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR SyncResponse: syncResponse, LogPath: e.config.LogPath, ClientMetrics: e.clientMetrics, + RefreshStatus: func() { + e.RunHealthProbes(true) + }, } bundleJobParams := debug.BundleConfig{ @@ -1654,6 +1678,7 @@ func (e *Engine) parseNATExternalIPMappings() []string { func (e *Engine) close() { log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) + if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) @@ -1845,7 +1870,7 @@ func (e *Engine) getRosenpassAddr() string { return "" } -// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services +// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services // and updates the status recorder with the latest states. 
func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() @@ -1859,23 +1884,8 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool { stuns := slices.Clone(e.STUNs) turns := slices.Clone(e.TURNs) - if e.wgInterface != nil { - stats, err := e.wgInterface.GetStats() - if err != nil { - log.Warnf("failed to get wireguard stats: %v", err) - e.syncMsgMux.Unlock() - return false - } - for _, key := range e.peerStore.PeersPubKey() { - // wgStats could be zero value, in which case we just reset the stats - wgStats, ok := stats[key] - if !ok { - continue - } - if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil { - log.Debugf("failed to update wg stats for peer %s: %s", key, err) - } - } + if err := e.statusRecorder.RefreshWireGuardStats(); err != nil { + log.Debugf("failed to refresh WireGuard stats: %v", err) } e.syncMsgMux.Unlock() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index af9f27a71..012c8ad6e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -107,6 +107,7 @@ type MockWGIface struct { GetStatsFunc func() (map[string]configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy + GetProxyPortFunc func() uint16 GetNetFunc func() *netstack.Net LastActivitiesFunc func() map[string]monotime.Time } @@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy { return m.GetProxyFunc() } +func (m *MockWGIface) GetProxyPort() uint16 { + if m.GetProxyPortFunc != nil { + return m.GetProxyPortFunc() + } + return 0 +} + func (m *MockWGIface) GetNet() *netstack.Net { return m.GetNetFunc() } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index f8a433a6e..39e9bacfa 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error 
GetProxy() wgproxy.Proxy + GetProxyPort() uint16 UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error diff --git a/client/internal/login.go b/client/internal/login.go deleted file mode 100644 index f528783ef..000000000 --- a/client/internal/login.go +++ /dev/null @@ -1,201 +0,0 @@ -package internal - -import ( - "context" - "net/url" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/client/system" - mgm "github.com/netbirdio/netbird/shared/management/client" - mgmProto "github.com/netbirdio/netbird/shared/management/proto" -) - -// IsLoginRequired check that the server is support SSO or not -func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) { - mgmURL := config.ManagementURL - mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL) - if err != nil { - return false, err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", mgmURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return false, err - } - - _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if isLoginNeeded(err) { - return true, nil - } - return false, err -} - -// Login or register the client -func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error { - mgmClient, err := getMgmClient(ctx, 
config.PrivateKey, config.ManagementURL) - if err != nil { - return err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", config.ManagementURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return err - } - - serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if serverKey != nil && isRegistrationNeeded(err) { - log.Debugf("peer registration required") - _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) - if err != nil { - return err - } - } else if err != nil { - return err - } - - return nil -} - -func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return nil, err - } - - var mgmTlsEnabled bool - if mgmURL.Scheme == "https" { - mgmTlsEnabled = true - } - - log.Debugf("connecting to the Management service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled) - if err != nil { - log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err) - return nil, err - } - return mgmClient, err -} - -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err - } - - sysInfo := system.GetInfo(ctx) - sysInfo.SetFlags( - config.RosenpassEnabled, - 
config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, loginResp, err -} - -// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. -// Otherwise tries to register with the provided setupKey via command line. -func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - validSetupKey, err := uuid.Parse(setupKey) - if err != nil && jwtToken == "" { - return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) - } - - log.Debugf("sending peer registration request to Management Service") - info := system.GetInfo(ctx) - info.SetFlags( - config.RosenpassEnabled, - config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) - if err != nil { - log.Errorf("failed registering peer %v", err) - return nil, err - } - - log.Infof("peer has been successfully registered on Management Service") - - return loginResp, nil -} - 
-func isLoginNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied { - return true - } - return false -} - -func isRegistrationNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.PermissionDenied { - return true - } - return false -} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 697bda2ff..abedc208e 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -1145,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) { return d.wgIface.FullStats() } +// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface +// and updates the cached peer states. This ensures accurate handshake times and +// transfer statistics in status reports without running full health probes. 
+func (d *Status) RefreshWireGuardStats() error { + d.mux.Lock() + defer d.mux.Unlock() + + if d.wgIface == nil { + return nil // silently skip if interface not set + } + + stats, err := d.wgIface.FullStats() + if err != nil { + return fmt.Errorf("get wireguard stats: %w", err) + } + + // Update each peer's WireGuard statistics + for _, peerStats := range stats.Peers { + peerState, ok := d.peers[peerStats.PublicKey] + if !ok { + continue + } + + peerState.LastWireguardHandshake = peerStats.LastHandshake + peerState.BytesRx = peerStats.RxBytes + peerState.BytesTx = peerStats.TxBytes + d.peers[peerStats.PublicKey] = peerState + } + + return nil +} + type EventQueue struct { maxSize int events []*proto.SystemEvent diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 840fc9241..b6b9d2cf4 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "time" @@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent RosenpassAddr: remoteOfferAnswer.RosenpassAddr, LocalIceCandidateType: pair.Local.Type().String(), RemoteIceCandidateType: pair.Remote.Type().String(), - LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), - RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), + LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())), + RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())), Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } @@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addrString := 
pair.Remote.Address() - parsed, err := netip.ParseAddr(addrString) - if (err == nil) && (parsed.Is6()) { - addrString = fmt.Sprintf("[%s]", addrString) - //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() - } - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort))) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return @@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, } } +func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { + sessionID := w.SessionID() + stats := agent.GetCandidatePairsStats() + localCandidates, _ := agent.GetLocalCandidates() + remoteCandidates, _ := agent.GetRemoteCandidates() + + localMap := make(map[string]ice.Candidate) + for _, c := range localCandidates { + localMap[c.ID()] = c + } + remoteMap := make(map[string]ice.Candidate) + for _, c := range remoteCandidates { + remoteMap[c.ID()] = c + } + + for _, stat := range stats { + if stat.State == ice.CandidatePairStateSucceeded { + local, lok := localMap[stat.LocalCandidateID] + remote, rok := remoteMap[stat.RemoteCandidateID] + if !lok || !rok { + continue + } + w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + sessionID, + local.NetworkType(), local.Type(), local.Address(), + remote.NetworkType(), remote.Type(), remote.Address(), + stat.CurrentRoundTripTime*1000) + } + } +} + func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { return func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) switch state { case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected + w.logSuccessfulPaths(agent) return case ice.ConnectionStateFailed, 
ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go deleted file mode 100644 index 23c92e8af..000000000 --- a/client/internal/pkce_auth.go +++ /dev/null @@ -1,138 +0,0 @@ -package internal - -import ( - "context" - "crypto/tls" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" - "github.com/netbirdio/netbird/shared/management/client/common" -) - -// PKCEAuthorizationFlow represents PKCE Authorization Flow information -type PKCEAuthorizationFlow struct { - ProviderConfig PKCEAuthProviderConfig -} - -// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow -type PKCEAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code - AuthorizationEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // RedirectURL handles authorization code from IDP manager - RedirectURLs []string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // ClientCertPair is used for mTLS authentication to the IDP - ClientCertPair *tls.Certificate - // DisablePromptLogin makes the PKCE flow to not prompt the user for login - DisablePromptLogin bool - // LoginFlag is used to configure the PKCE flow login behavior - LoginFlag 
common.LoginFlag - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it -func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return PKCEAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return PKCEAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return PKCEAuthorizationFlow{}, err - } - - protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find pkce flow, contact admin: %v", err) - return PKCEAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve pkce flow: %v", err) - return PKCEAuthorizationFlow{}, err - } - - authFlow := PKCEAuthorizationFlow{ - ProviderConfig: PKCEAuthProviderConfig{ - Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: 
protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(), - TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(), - Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(), - RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(), - UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - ClientCertPair: clientCert, - DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(), - LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()), - }, - } - - err = isPKCEProviderConfigValid(authFlow.ProviderConfig) - if err != nil { - return PKCEAuthorizationFlow{}, err - } - - return authFlow, nil -} - -func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. 
Contact your NetBird administrator" - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.AuthorizationEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes") - } - if config.RedirectURLs == nil { - return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs") - } - return nil -} diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index 26a1eef58..1faa22dc5 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -17,6 +17,11 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const ( + defaultLog = slog.LevelInfo + defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL" +) + func hashRosenpassKey(key []byte) string { hasher := sha256.New() hasher.Write(key) @@ -45,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error) } rpKeyHash := hashRosenpassKey(public) - log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash) + log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash) return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil } @@ -101,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error { func (m *Manager) generateConfig() (rp.Config, error) { opts := &slog.HandlerOptions{ - Level: slog.LevelDebug, + Level: getLogLevel(), } logger := slog.New(slog.NewTextHandler(os.Stdout, opts)) cfg := rp.Config{Logger: logger} @@ -133,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) { return cfg, nil } +func getLogLevel() slog.Level { + level, ok := os.LookupEnv(defaultLogLevelVar) + if !ok { + return defaultLog + } + 
switch strings.ToLower(level) { + case "debug": + return slog.LevelDebug + case "info": + return slog.LevelInfo + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String()) + return defaultLog + } +} + func (m *Manager) OnDisconnected(peerKey string) { m.lock.Lock() defer m.lock.Unlock() diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 935910fc9..aafef41d3 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool { return true } - needsLogin, err := internal.IsLoginRequired(ctx, cfg) + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + log.Errorf("IsLoginRequired: failed to create auth client: %v", err) + return true // Assume login is required if we can't create auth client + } + defer authClient.Close() + + needsLogin, err := authClient.IsLoginRequired(ctx) if err != nil { log.Errorf("IsLoginRequired: check failed: %v", err) // If the check fails, assume login is required to be safe @@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string { // This could cause a potential race condition with loading the extension which need to be handled on swift side go func() { - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo) if err != nil { log.Errorf("LoginForMobile: WaitToken failed: %v", err) return } jwtToken := tokenInfo.GetTokenToUse() - if err := internal.Login(ctx, cfg, "", jwtToken); err != nil { + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + log.Errorf("LoginForMobile: failed to create auth client: %v", err) + return + } + defer 
authClient.Close() + if err, _ := authClient.Login(ctx, "", jwtToken); err != nil { log.Errorf("LoginForMobile: Login failed: %v", err) return } diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 27fdcf5ef..9d447ef3f 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -7,13 +7,8 @@ import ( "fmt" "time" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" @@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { } func (a *Auth) saveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - s, ok := gstatus.FromError(err) - if !ok { - return err - } - if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - // Use 
DirectWriteOutConfig to avoid atomic file operations (temp file + rename) // which are blocked by the tvOS sandbox in App Group containers err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config) @@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK } func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + err, _ = authClient.Login(ctxWithValues, setupKey, "") if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename) @@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string // LoginSync performs a synchronous login check without UI interaction // Used for background VPN connection where user should already be authenticated func (a *Auth) LoginSync() error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := 
authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" @@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error { return fmt.Errorf("not authenticated") } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { - // PermissionDenied means registration is required or peer is blocked - return backoff.Permanent(err) - } - return err - }) + err, isAuthError := authClient.Login(a.ctx, "", jwtToken) if err != nil { + if isAuthError { + // PermissionDenied means registration is required or peer is blocked + return fmt.Errorf("authentication error: %v", err) + } return fmt.Errorf("login failed: %v", err) } @@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen } func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error { - var needsLogin bool - // Create context with device name if provided ctx := a.ctx if deviceName != "" { @@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) } - // check if we need to generate JWT token - err := a.withBackOff(ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(ctx, a.config) - return - }) + authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + // check if we need to generate JWT token + needsLogin, err := authClient.IsLoginRequired(ctx) + if err != nil { + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" if needsLogin { - tokenInfo, err := 
a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth) + tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } jwtToken = tokenInfo.GetTokenToUse() } - err = a.withBackOff(ctx, func() error { - err := internal.Login(ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { - // PermissionDenied means registration is required or peer is blocked - return backoff.Permanent(err) - } - return err - }) + err, isAuthError := authClient.Login(ctx, "", jwtToken) if err != nil { + if isAuthError { + // PermissionDenied means registration is required or peer is blocked + return fmt.Errorf("authentication error: %v", err) + } return fmt.Errorf("login failed: %v", err) } @@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin const authInfoRequestTimeout = 30 * time.Second -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "") +func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) { + oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get OAuth flow: %v", err) } // Use a bounded timeout for the auth info request to prevent indefinite hangs @@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) return &tokenInfo, nil } -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) -} - // GetConfigJSON 
returns the current config as a JSON string. // This can be used by the caller to persist the config via alternative storage // mechanisms (e.g., UserDefaults on tvOS where file writes are blocked). diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 9cbe34e1d..1d9d7233c 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.33.1 +// protoc v6.32.1 // source: daemon.proto package proto @@ -5364,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { return 0 } +// StartCPUProfileRequest for starting CPU profiling +type StartCPUProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCPUProfileRequest) Reset() { + *x = StartCPUProfileRequest{} + mi := &file_daemon_proto_msgTypes[79] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCPUProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCPUProfileRequest) ProtoMessage() {} + +func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[79] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead. 
+func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{79} +} + +// StartCPUProfileResponse confirms CPU profiling has started +type StartCPUProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCPUProfileResponse) Reset() { + *x = StartCPUProfileResponse{} + mi := &file_daemon_proto_msgTypes[80] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCPUProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCPUProfileResponse) ProtoMessage() {} + +func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[80] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead. 
+func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{80} +} + +// StopCPUProfileRequest for stopping CPU profiling +type StopCPUProfileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopCPUProfileRequest) Reset() { + *x = StopCPUProfileRequest{} + mi := &file_daemon_proto_msgTypes[81] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopCPUProfileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopCPUProfileRequest) ProtoMessage() {} + +func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[81] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead. 
+func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{81} +} + +// StopCPUProfileResponse confirms CPU profiling has stopped +type StopCPUProfileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopCPUProfileResponse) Reset() { + *x = StopCPUProfileResponse{} + mi := &file_daemon_proto_msgTypes[82] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopCPUProfileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopCPUProfileResponse) ProtoMessage() {} + +func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[82] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead. 
+func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{82} +} + type InstallerResultRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -5372,7 +5520,7 @@ type InstallerResultRequest struct { func (x *InstallerResultRequest) Reset() { *x = InstallerResultRequest{} - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[83] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5384,7 +5532,7 @@ func (x *InstallerResultRequest) String() string { func (*InstallerResultRequest) ProtoMessage() {} func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[79] + mi := &file_daemon_proto_msgTypes[83] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5397,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead. 
func (*InstallerResultRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{79} + return file_daemon_proto_rawDescGZIP(), []int{83} } type InstallerResultResponse struct { @@ -5410,7 +5558,7 @@ type InstallerResultResponse struct { func (x *InstallerResultResponse) Reset() { *x = InstallerResultResponse{} - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[84] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5422,7 +5570,7 @@ func (x *InstallerResultResponse) String() string { func (*InstallerResultResponse) ProtoMessage() {} func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[80] + mi := &file_daemon_proto_msgTypes[84] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5435,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead. 
func (*InstallerResultResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{80} + return file_daemon_proto_rawDescGZIP(), []int{84} } func (x *InstallerResultResponse) GetSuccess() bool { @@ -5462,7 +5610,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[86] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5474,7 +5622,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[82] + mi := &file_daemon_proto_msgTypes[86] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5994,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" + + "\x16StartCPUProfileRequest\"\x19\n" + + "\x17StartCPUProfileResponse\"\x17\n" + + "\x15StopCPUProfileRequest\"\x18\n" + + "\x16StopCPUProfileResponse\"\x18\n" + "\x16InstallerResultRequest\"O\n" + "\x17InstallerResultResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + @@ -6006,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\xb4\x13\n" + + "\x05TRACE\x10\a2\xdd\x14\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6041,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 
.daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + - "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" + + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" @@ -6058,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType @@ -6143,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{ (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse - (*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest - (*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse - nil, // 85: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range - nil, // 87: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 88: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp + 
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest + (*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse + (*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest + (*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse + (*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest + (*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse + nil, // 89: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range + nil, // 91: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 92: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState @@ -6168,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{ 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 26, // 14: daemon.FullStatus.sshServerState:type_name 
-> daemon.SSHServerState 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule @@ -6180,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{ 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest @@ -6217,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{ 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> 
daemon.RequestJWTAuthRequest 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 59, // 86: 
daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 67, // [67:100] is the sub-list for method output_type - 34, // [34:67] is the sub-list for method input_type + 83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 72: 
daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 94: 
daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 69, // [69:104] is the sub-list for method output_type + 34, // [34:69] is the sub-list for method input_type 34, // [34:34] is the sub-list for extension type_name 34, // [34:34] is the sub-list for extension extendee 0, // [0:34] is the sub-list for field type_name @@ -6283,7 +6445,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 4, - NumMessages: 84, + NumMessages: 88, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 7a802d830..68b9a9348 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -94,6 +94,12 @@ service DaemonService { // WaitJWTToken waits for JWT authentication completion rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} +// StartCPUProfile starts CPU profiling in the daemon + rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {} + + // StopCPUProfile stops CPU 
profiling in the daemon + rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {} + rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} @@ -776,6 +782,18 @@ message WaitJWTTokenResponse { int64 expiresIn = 3; } +// StartCPUProfileRequest for starting CPU profiling +message StartCPUProfileRequest {} + +// StartCPUProfileResponse confirms CPU profiling has started +message StartCPUProfileResponse {} + +// StopCPUProfileRequest for stopping CPU profiling +message StopCPUProfileRequest {} + +// StopCPUProfileResponse confirms CPU profiling has stopped +message StopCPUProfileResponse {} + message InstallerResultRequest { } diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index fdabb1879..ea9b4df05 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -70,6 +70,10 @@ type DaemonServiceClient interface { RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) + // StartCPUProfile starts CPU profiling in the daemon + StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) + // StopCPUProfile stops CPU profiling in the daemon + StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) } @@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken 
return out, nil } +func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) { + out := new(StartCPUProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) { + out := new(StopCPUProfileResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { out := new(OSLifecycleResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) @@ -458,6 +480,10 @@ type DaemonServiceServer interface { RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) + // StartCPUProfile starts CPU profiling in the daemon + StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) + // StopCPUProfile stops CPU profiling in the daemon + StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) mustEmbedUnimplementedDaemonServiceServer() @@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { 
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") } +func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented") +} +func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented") +} func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") } @@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartCPUProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartCPUProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/StartCPUProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopCPUProfileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopCPUProfile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, 
+ FullMethod: "/daemon.DaemonService/StopCPUProfile", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(OSLifecycleRequest) if err := dec(in); err != nil { @@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "WaitJWTToken", Handler: _DaemonService_WaitJWTToken_Handler, }, + { + MethodName: "StartCPUProfile", + Handler: _DaemonService_StartCPUProfile_Handler, + }, + { + MethodName: "StopCPUProfile", + Handler: _DaemonService_StopCPUProfile_Handler, + }, { MethodName: "NotifyOSLifecycle", Handler: _DaemonService_NotifyOSLifecycle_Handler, diff --git a/client/server/debug.go b/client/server/debug.go index d3f27af55..e563f79c1 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -3,9 +3,11 @@ package server import ( + "bytes" "context" "errors" "fmt" + "runtime/pprof" log "github.com/sirupsen/logrus" @@ -31,12 +33,34 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( } } + var cpuProfileData []byte + if s.cpuProfileBuf != nil && !s.cpuProfiling { + cpuProfileData = s.cpuProfileBuf.Bytes() + defer func() { + s.cpuProfileBuf = nil + }() + } + + // Prepare refresh callback for health probes + var refreshStatus func() + if s.connectClient != nil { + engine := s.connectClient.Engine() + if engine != nil { + refreshStatus = func() { + log.Debug("refreshing system health status for debug bundle") + engine.RunHealthProbes(true) + } + } + } + bundleGenerator := debug.NewBundleGenerator( debug.GeneratorDependencies{ InternalConfig: s.config, StatusRecorder: s.statusRecorder, SyncResponse: syncResponse, LogPath: s.logFile, + CPUProfile: 
cpuProfileData, + RefreshStatus: refreshStatus, ClientMetrics: clientMetrics, }, debug.BundleConfig{ @@ -117,3 +141,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) { return cClient.GetLatestSyncResponse() } + +// StartCPUProfile starts CPU profiling in the daemon. +func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.cpuProfiling { + return nil, fmt.Errorf("CPU profiling already in progress") + } + + s.cpuProfileBuf = &bytes.Buffer{} + s.cpuProfiling = true + if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil { + s.cpuProfileBuf = nil + s.cpuProfiling = false + return nil, fmt.Errorf("start CPU profile: %w", err) + } + + log.Info("CPU profiling started") + return &proto.StartCPUProfileResponse{}, nil +} + +// StopCPUProfile stops CPU profiling in the daemon. +func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.cpuProfiling { + return nil, fmt.Errorf("CPU profiling not in progress") + } + + pprof.StopCPUProfile() + s.cpuProfiling = false + + if s.cpuProfileBuf != nil { + log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len()) + } + + return &proto.StopCPUProfileResponse{}, nil +} diff --git a/client/server/server.go b/client/server/server.go index 408bd56db..108eab9fe 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "errors" "fmt" @@ -77,6 +78,9 @@ type Server struct { persistSyncResponse bool isSessionActive atomic.Bool + cpuProfileBuf *bytes.Buffer + cpuProfiling bool + profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool @@ -249,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil // loginAttempt 
attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { - var status internal.StatusType - err := internal.Login(ctx, s.config, setupKey, jwtToken) + authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config) if err != nil { - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + log.Errorf("failed to create auth client: %v", err) + return internal.StatusLoginFailed, err + } + defer authClient.Close() + + var status internal.StatusType + err, isAuthError := authClient.Login(ctx, setupKey, jwtToken) + if err != nil { + if isAuthError { log.Warnf("failed login: %v", err) status = internal.StatusNeedsLogin } else { @@ -577,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.oauthAuthFlow.waitCancel() } - waitTimeout := time.Until(s.oauthAuthFlow.expiresAt) - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) + waitCTX, cancel := context.WithCancel(ctx) defer cancel() s.mutex.Lock() @@ -1323,6 +1333,10 @@ func (s *Server) runProbes(waitForProbeResult bool) { if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } + } else { + if err := s.statusRecorder.RefreshWireGuardStats(); err != nil { + log.Debugf("failed to refresh WireGuard stats: %v", err) + } } } diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index cb1c36e13..8897b9c7e 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { } func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { - // Create a backend session to mirror the client's session request. 
- // This keeps the connection alive on the server side while port forwarding channels operate. serverSession, err := sshClient.NewSession() if err != nil { _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) @@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c } defer func() { _ = serverSession.Close() }() - <-session.Context().Done() + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() - if err := session.Exit(0); err != nil { - log.Debugf("session exit: %v", err) + if err := serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } + + done := make(chan error, 1) + go func() { + done <- serverSession.Wait() + }() + + select { + case <-session.Context().Done(): + return + case err := <-done: + if err != nil { + log.Debugf("shell session: %v", err) + p.handleProxyExitCode(session, err) + } } } diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index 7a01ce4f6..b0a85fe4b 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -12,8 +12,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handleCommand executes an SSH command with privilege validation -func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { +// handleExecution executes an SSH command or shell with privilege validation +func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) { hasPty := winCh != nil commandType := "command" @@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) - execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + execCmd, cleanup, err 
:= s.createCommand(logger, privilegeResult, session, hasPty) if err != nil { logger.Errorf("%s creation failed: %v", commandType, err) @@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege defer cleanup() - ptyReq, _, _ := session.Pty() if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { logger.Debugf("%s execution completed", commandType) } } -func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { +func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { localUser := privilegeResult.User if localUser == nil { return nil, nil, errors.New("no user in privilege result") @@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh // If PTY requested but su doesn't support --pty, skip su and use executor // This ensures PTY functionality is provided (executor runs within our allocated PTY) if hasPty && !s.suSupportsPty { - log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } // Try su first for system integration (PAM/audit) when privileged - cmd, err := s.createSuCommand(session, localUser, hasPty) + cmd, err := s.createSuCommand(logger, session, localUser, hasPty) if err != nil || privilegeResult.UsedFallback { - log.Debugf("su command failed, falling back to 
executor: %v", err) - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, func() {}, nil } diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 01759a337..3aeaa135c 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -15,17 +15,17 @@ import ( var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") // createSuCommand is not supported on JS/WASM -func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { +func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { return nil, errNotSupported } // createExecutorCommand is not supported on JS/WASM -func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { +func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { return nil, nil, errNotSupported } // prepareCommandEnv is not supported on JS/WASM -func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string { return nil } diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index db1a9bcfe..279b89341 100644 --- a/client/ssh/server/command_execution_unix.go +++ 
b/client/ssh/server/command_execution_unix.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "runtime" "strings" "sync" @@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { return isUtilLinux } -// createSuCommand creates a command using su -l -c for privilege switching -func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { +// createSuCommand creates a command using su - for privilege switching. +func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + if err := validateUsername(localUser.Username); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + suPath, err := exec.LookPath("su") if err != nil { return nil, fmt.Errorf("su command not available: %w", err) } - command := session.RawCommand() - if command == "" { - return nil, fmt.Errorf("no command specified for su execution") - } - - args := []string{"-l"} + args := []string{"-"} if hasPty && s.suSupportsPty { args = append(args, "--pty") } - args = append(args, localUser.Username, "-c", command) + args = append(args, localUser.Username) + command := session.RawCommand() + if command != "" { + args = append(args, "-c", command) + } + + logger.Debugf("creating su command: %s %v", suPath, args) cmd := exec.CommandContext(session.Context(), suPath, args...) cmd.Dir = localUser.HomeDir return cmd, nil } -// getShellCommandArgs returns the shell command and arguments for executing a command string +// getShellCommandArgs returns the shell command and arguments for executing a command string. 
func (s *Server) getShellCommandArgs(shell, cmdString string) []string { if cmdString == "" { - return []string{shell, "-l"} + return []string{shell} } - return []string{shell, "-l", "-c", cmdString} + return []string{shell, "-c", cmdString} +} + +// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname". +func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd { + cmd := exec.CommandContext(ctx, shell, args[1:]...) + cmd.Args[0] = "-" + filepath.Base(shell) + return cmd } // prepareCommandEnv prepares environment variables for command execution on Unix -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string { env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env = append(env, prepareSSHEnv(session)...) for _, v := range session.Environ() { @@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) if err != nil { logger.Errorf("Pty command creation failed: %v", err) @@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty }() go func() { - defer func() { - if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { - logger.Debugf("session close error: %v", err) - } - }() if _, err := io.Copy(session, ptmx); err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { logger.Warnf("Pty 
output copy error: %v", err) @@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex case <-ctx.Done(): s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) case err := <-done: - s.handlePtyCommandCompletion(logger, session, err) + s.handlePtyCommandCompletion(logger, session, ptyMgr, err) } } @@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses } } -func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) { if err != nil { logger.Debugf("Pty command execution failed: %v", err) s.handleSessionExit(session, err, logger) - return + } else { + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } } - // Normal completion - logger.Debugf("Pty command completed successfully") - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) + // Close PTY to unblock io.Copy goroutines + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close after completion: %v", err) } } diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 998796871..e1ba777f6 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -20,32 +20,32 @@ import ( // getUserEnvironment retrieves the Windows environment for the target user. // Follows OpenSSH's resilient approach with graceful degradation on failures. 
-func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { - userToken, err := s.getUserToken(username, domain) +func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) { + userToken, err := s.getUserToken(logger, username, domain) if err != nil { return nil, fmt.Errorf("get user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() - return s.getUserEnvironmentWithToken(userToken, username, domain) + return s.getUserEnvironmentWithToken(logger, userToken, username, domain) } // getUserEnvironmentWithToken retrieves the Windows environment using an existing token. -func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { +func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) { userProfile, err := s.loadUserProfile(userToken, username, domain) if err != nil { - log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) userProfile = fmt.Sprintf("C:\\Users\\%s", username) } envMap := make(map[string]string) if err := s.loadSystemEnvironment(envMap); err != nil { - log.Debugf("failed to load system environment from registry: %v", err) + logger.Debugf("failed to load system environment from registry: %v", err) } s.setUserEnvironmentVariables(envMap, userProfile, username, domain) @@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, } // getUserToken creates a user token for the specified user. 
-func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { - privilegeDropper := NewPrivilegeDropper() +func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) token, err := privilegeDropper.createToken(username, domain) if err != nil { return 0, fmt.Errorf("generate S4U user token: %w", err) @@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi } // prepareCommandEnv prepares environment variables for command execution on Windows -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string { username, domain := s.parseUsername(localUser.Username) - userEnv, err := s.getUserEnvironment(username, domain) + userEnv, err := s.getUserEnvironment(logger, username, domain) if err != nil { log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) @@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) [] return env } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { if privilegeResult.User == nil { logger.Errorf("no user in privilege result") return false } - cmd := session.Command() shell := getUserShell(privilegeResult.User.Uid) + logger.Infof("starting interactive shell: %s", shell) - if len(cmd) == 0 { - logger.Infof("starting interactive shell: %s", shell) - } else { - logger.Infof("executing command: %s", safeLogCommand(cmd)) - } - - 
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil) return true } @@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string { return []string{shell, "-Command", cmdString} } -func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { - logger.Info("starting interactive shell") - s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) -} - type PtyExecutionRequest struct { Shell string Command string @@ -308,25 +297,25 @@ type PtyExecutionRequest struct { Domain string } -func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { - log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", +func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error { + logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) userToken, err := privilegeDropper.createToken(req.Username, req.Domain) if err != nil { return fmt.Errorf("create user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() server := &Server{} - userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain) if err != nil { - log.Debugf("failed to get user environment for %s\\%s, using system 
environment: %v", req.Domain, req.Username, err) + logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) userEnv = os.Environ() } @@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re Environment: userEnv, } - log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) - return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) + logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig) } func getUserHomeFromEnv(env []string) string { @@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) { return } - logger := log.WithField("pid", cmd.Process.Pid) - if err := cmd.Process.Kill(); err != nil { - logger.Debugf("kill process failed: %v", err) + log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err) } } @@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool { } // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty -func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { - command := session.RawCommand() - if command == "" { - logger.Error("no command specified for PTY execution") - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - return false - } - - return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) -} - -// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) -func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { +func (s *Server) executeCommandWithPty(logger 
*log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { localUser := privilegeResult.User if localUser == nil { logger.Errorf("no user in privilege result") @@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr req := PtyExecutionRequest{ Shell: shell, - Command: command, + Command: session.RawCommand(), Width: ptyReq.Window.Width, Height: ptyReq.Window.Height, Username: username, Domain: domain, } - if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + if err := executePtyCommandWithUserToken(logger, session, req); err != nil { logger.Errorf("ConPty execution failed: %v", err) if err := session.Exit(1); err != nil { logSessionExitError(logger, err) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 34ffccfd2..7fe2d6c5e 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -4,12 +4,15 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "errors" "fmt" "io" "net" "os" "os/exec" + "path/filepath" "runtime" + "slices" "strings" "testing" "time" @@ -23,25 +26,67 @@ import ( "github.com/netbirdio/netbird/client/ssh/testutil" ) -// TestMain handles package-level setup and cleanup func TestMain(m *testing.M) { - // Guard against infinite recursion when test binary is called as "netbird ssh exec" - // This happens when running tests as non-privileged user with fallback + // On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server + // spawns an executor subprocess via os.Executable(). During tests, this invokes the test + // binary with "ssh exec" args. We handle that here to properly execute commands and + // propagate exit codes. 
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { - // Just exit with error to break the recursion - fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") - os.Exit(1) + runTestExecutor() + return } - // Run tests code := m.Run() - - // Cleanup any created test users testutil.CleanupTestUsers() - os.Exit(code) } +// runTestExecutor emulates the netbird executor for tests. +// Parses --shell and --cmd args, runs the command, and exits with the correct code. +func runTestExecutor() { + if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" { + fmt.Fprintf(os.Stderr, "executor recursion detected\n") + os.Exit(1) + } + os.Setenv("_NETBIRD_TEST_EXECUTOR", "1") + + shell := "/bin/sh" + var command string + for i := 3; i < len(os.Args); i++ { + switch os.Args[i] { + case "--shell": + if i+1 < len(os.Args) { + shell = os.Args[i+1] + i++ + } + case "--cmd": + if i+1 < len(os.Args) { + command = os.Args[i+1] + i++ + } + } + } + + var cmd *exec.Cmd + if command == "" { + cmd = exec.Command(shell) + } else { + cmd = exec.Command(shell, "-c", command) + } + cmd.Args[0] = "-" + filepath.Base(shell) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } + os.Exit(0) +} + // TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client func TestSSHServerCompatibility(t *testing.T) { if testing.Short() { @@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { return createTempKeyFileFromBytes(t, privateKey) } +// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags) +// This ensures our implementation matches OpenSSH behavior for: +// - ssh host command (no PTY - default when no TTY) +// - ssh -T host command (explicit no PTY) +// - ssh -t host command (force PTY) +// - ssh -T host (no PTY 
shell - our implementation) +func TestSSHPtyModes(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH PTY mode tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues") + } + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + username := testutil.GetTestUsername(t) + + baseArgs := []string{ + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "BatchMode=yes", + } + + t.Run("command_default_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (default no PTY) failed: %s", output) + assert.Contains(t, string(output), "no_pty_default") + }) + + t.Run("command_explicit_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") + cmd := exec.Command("ssh", args...) 
+ + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output) + assert.Contains(t, string(output), "explicit_no_pty") + }) + + t.Run("command_force_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (-tt force PTY) failed: %s", output) + assert.Contains(t, string(output), "force_pty") + }) + + t.Run("shell_explicit_no_pty", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host)) + cmd := exec.CommandContext(ctx, "ssh", args...) + + stdin, err := cmd.StdinPipe() + require.NoError(t, err) + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed") + + go func() { + defer stdin.Close() + time.Sleep(100 * time.Millisecond) + _, err := stdin.Write([]byte("echo shell_no_pty_test\n")) + assert.NoError(t, err, "write echo command") + time.Sleep(100 * time.Millisecond) + _, err = stdin.Write([]byte("exit 0\n")) + assert.NoError(t, err, "write exit command") + }() + + output, _ := io.ReadAll(stdout) + err = cmd.Wait() + + require.NoError(t, err, "Shell (-T no PTY) failed: %s", output) + assert.Contains(t, string(output), "shell_no_pty_test") + }) + + t.Run("exit_code_preserved_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") + cmd := exec.Command("ssh", args...) 
+ + err := cmd.Run() + require.Error(t, err, "Command should exit with non-zero") + + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) + assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") + }) + + t.Run("exit_code_preserved_with_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'") + cmd := exec.Command("ssh", args...) + + err := cmd.Run() + require.Error(t, err, "PTY command should exit with non-zero") + + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) + assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt") + }) + + t.Run("stderr_works_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) + + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + require.NoError(t, cmd.Run(), "stderr test failed") + assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg") + assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg") + assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg") + }) + + t.Run("stderr_merged_with_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) 
+ + output, err := cmd.CombinedOutput() + require.NoError(t, err, "PTY stderr test failed: %s", output) + assert.Contains(t, string(output), "stdout_msg") + assert.Contains(t, string(output), "stderr_msg") + }) +} + // TestSSHServerFeatureCompatibility tests specific SSH features for compatibility func TestSSHServerFeatureCompatibility(t *testing.T) { if testing.Short() { diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go index 8adc824ef..ee0b0ff78 100644 --- a/client/ssh/server/executor_unix.go +++ b/client/ssh/server/executor_unix.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "runtime" "strings" "syscall" @@ -35,11 +36,35 @@ type ExecutorConfig struct { } // PrivilegeDropper handles secure privilege dropping in child processes -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} + +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type PrivilegeDropperOption func(*PrivilegeDropper) // NewPrivilegeDropper creates a new privilege dropper -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } +} + +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } // CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping @@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex break } } - log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + 
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs) return exec.CommandContext(ctx, netbirdPath, args...), nil } @@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config var execCmd *exec.Cmd if config.Command == "" { - os.Exit(ExitCodeSuccess) + execCmd = exec.CommandContext(ctx, config.Shell) + } else { + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) } - - execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) + execCmd.Args[0] = "-" + filepath.Base(config.Shell) execCmd.Stdin = os.Stdin execCmd.Stdout = os.Stdout execCmd.Stderr = os.Stderr - cmdParts := strings.Fields(config.Command) - safeCmd := safeLogCommand(cmdParts) - log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if config.Command == "" { + log.Tracef("executing login shell: %s", execCmd.Path) + } else { + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + } if err := execCmd.Run(); err != nil { var exitError *exec.ExitError if errors.As(err, &exitError) { diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index d3504e056..51c995ec3 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -28,22 +28,45 @@ const ( ) type WindowsExecutorConfig struct { - Username string - Domain string - WorkingDir string - Shell string - Command string - Args []string - Interactive bool - Pty bool - PtyWidth int - PtyHeight int + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Pty bool + PtyWidth int + PtyHeight int } -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type 
PrivilegeDropperOption func(*PrivilegeDropper) + +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } +} + +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } var ( @@ -56,7 +79,6 @@ const ( // Common error messages commandFlag = "-Command" - closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials convertUsernameError = "convert username to UTF16: %w" convertDomainError = "convert domain to UTF16: %w" ) @@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co shellArgs = []string{shell} } - log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) cmd, token, err := pd.CreateWindowsProcessAsUser( ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) @@ -180,10 +202,10 @@ func newLsaString(s string) lsaString { // generateS4UUserToken creates a Windows token using S4U authentication // This is the exact approach OpenSSH for Windows uses for public key authentication -func generateS4UUserToken(username, domain string) (windows.Handle, error) { +func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { userCpn := buildUserCpn(username, domain) - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(WithLogger(logger)) isDomainUser := !pd.isLocalUser(domain) lsaHandle, err := initializeLsaConnection() @@ -197,12 +219,12 @@ func 
generateS4UUserToken(username, domain string) (windows.Handle, error) { return 0, err } - logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser) if err != nil { return 0, err } - return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) + return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) } // buildUserCpn constructs the user principal name @@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) { } // prepareS4ULogonStructure creates the appropriate S4U logon structure -func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { +func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { if isDomainUser { - return prepareDomainS4ULogon(username, domain) + return prepareDomainS4ULogon(logger, username, domain) } - return prepareLocalS4ULogon(username) + return prepareLocalS4ULogon(logger, username) } // prepareDomainS4ULogon creates S4U logon structure for domain users -func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { +func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) { upn, err := lookupPrincipalName(username, domain) if err != nil { return nil, 0, fmt.Errorf("lookup principal name: %w", err) } - log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) upnUtf16, err := windows.UTF16FromString(upn) if err != nil { @@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er } // prepareLocalS4ULogon creates S4U logon structure for local users -func prepareLocalS4ULogon(username 
string) (unsafe.Pointer, uintptr, error) { - log.Debugf("using Msv1_0S4ULogon for local user: %s", username) +func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) { + logger.Debugf("using Msv1_0S4ULogon for local user: %s", username) usernameUtf16, err := windows.UTF16FromString(username) if err != nil { @@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { } // performS4ULogon executes the S4U logon operation -func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { +func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { var tokenSource tokenSource copy(tokenSource.SourceName[:], "netbird") if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { - log.Debugf("AllocateLocallyUniqueId failed") + logger.Debugf("AllocateLocallyUniqueId failed") } originName := newLsaString("netbird") @@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u if profile != 0 { if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { - log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) } } @@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) } - log.Debugf("created S4U %s token for user %s", + logger.Debugf("created S4U %s token for user %s", map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) return token, nil } @@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool { 
// authenticateLocalUser handles authentication for local users func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for local user %s", fullUsername) - token, err := generateS4UUserToken(username, ".") + pd.log().Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, ".") if err != nil { return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) } @@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) // authenticateDomainUser handles authentication for domain users func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for domain user %s", fullUsername) - token, err := generateS4UUserToken(username, domain) + pd.log().Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, domain) if err != nil { return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) } - log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername) return token, nil } @@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec defer func() { if err := windows.CloseHandle(token); err != nil { - log.Debugf("close impersonation token: %v", err) + pd.log().Debugf("close impersonation token: %v", err) } }() @@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo return cmd, primaryToken, nil } -// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) -func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { +// createSuCommand 
creates a command using su - for privilege switching (Windows stub). +func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) { return nil, fmt.Errorf("su command not available on Windows") } diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index dbef011ac..b2f3ac6a0 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) { server.SetAllowRootLogin(true) serverAddr := StartTestServer(t, server) - defer require.NoError(t, server.Stop()) + defer func() { require.NoError(t, server.Stop()) }() host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) @@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) { serverNoJWT.SetAllowRootLogin(true) serverAddrNoJWT := StartTestServer(t, serverNoJWT) - defer require.NoError(t, serverNoJWT.Stop()) + defer func() { require.NoError(t, serverNoJWT.Stop()) }() hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT) require.NoError(t, err) @@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) { server.SetAllowRootLogin(true) serverAddr := StartTestServer(t, server) - defer require.NoError(t, server.Stop()) + defer func() { require.NoError(t, server.Stop()) }() host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) @@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) { server.SetAllowRootLogin(true) serverAddr := StartTestServer(t, server) - defer require.NoError(t, server.Stop()) + defer func() { require.NoError(t, server.Stop()) }() host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) @@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) { server.UpdateSSHAuth(authConfig) serverAddr := StartTestServer(t, server) - defer require.NoError(t, server.Stop()) + defer func() { require.NoError(t, server.Stop()) }() host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) @@ -715,7 +715,7 @@ func 
TestJWTMultipleAudiences(t *testing.T) { server.UpdateSSHAuth(authConfig) serverAddr := StartTestServer(t, server) - defer require.NoError(t, server.Stop()) + defer func() { require.NoError(t, server.Stop()) }() host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index c60cf4f58..e16ff5d46 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool { return s.allowRemotePortForwarding } -// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled -func (s *Server) isPortForwardingEnabled() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.allowLocalPortForwarding || s.allowRemotePortForwarding -} - // parseTcpipForwardRequest parses the SSH request payload func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { var payload tcpipForwardMsg diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index e897bbade..1ddb60f8e 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { sessions = append(sessions, info) } - // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) + // Add authenticated connections without sessions (e.g., -N or port-forwarding only) for key, connState := range s.connections { remoteAddr := string(key) if reportedAddrs[remoteAddr] { diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index d85d85a51..f70e29963 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) { } } -func TestServer_PortForwardingOnlySession(t *testing.T) { - // Test that sessions 
without PTY and command are allowed when port forwarding is enabled +func TestServer_NonPtyShellSession(t *testing.T) { + // Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings. currentUser, err := user.Current() require.NoError(t, err, "Should be able to get current user") - // Generate host key for server hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) require.NoError(t, err) @@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { name string allowLocalForwarding bool allowRemoteForwarding bool - expectAllowed bool - description string }{ { - name: "session_allowed_with_local_forwarding", + name: "shell_with_local_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: false, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when local forwarding is enabled", }, { - name: "session_allowed_with_remote_forwarding", + name: "shell_with_remote_forwarding_enabled", allowLocalForwarding: false, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when remote forwarding is enabled", }, { - name: "session_allowed_with_both", + name: "shell_with_both_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when both forwarding types enabled", }, { - name: "session_denied_without_forwarding", + name: "shell_with_forwarding_disabled", allowLocalForwarding: false, allowRemoteForwarding: false, - expectAllowed: false, - description: "Port-forwarding-only session should be denied when all forwarding is disabled", }, } @@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { _ = server.Stop() }() - // Connect to the server without requesting PTY or command ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -557,20 +545,10 @@ func 
TestServer_PortForwardingOnlySession(t *testing.T) { _ = client.Close() }() - // Execute a command without PTY - this simulates ssh -T with no command - // The server should either allow it (port forwarding enabled) or reject it - output, err := client.ExecuteCommand(ctx, "") - if tt.expectAllowed { - // When allowed, the session stays open until cancelled - // ExecuteCommand with empty command should return without error - assert.NoError(t, err, "Session should be allowed when port forwarding is enabled") - assert.NotContains(t, output, "port forwarding is disabled", - "Output should not contain port forwarding disabled message") - } else if err != nil { - // When denied, we expect an error message about port forwarding being disabled - assert.Contains(t, err.Error(), "port forwarding is disabled", - "Should get port forwarding disabled message") - } + // Execute without PTY and no command - simulates ssh -T (shell without PTY) + // Should always succeed regardless of port forwarding settings + _, err = client.ExecuteCommand(ctx, "") + assert.NoError(t, err, "Non-PTY shell session should be allowed") }) } } diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go index 661068539..89fab717f 100644 --- a/client/ssh/server/server_test.go +++ b/client/ssh/server/server_test.go @@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) { assert.Equal(t, "-Command", args[1]) assert.Equal(t, "echo test", args[2]) } else { - // Test Unix shell behavior args := server.getShellCommandArgs("/bin/sh", "echo test") assert.Equal(t, "/bin/sh", args[0]) - assert.Equal(t, "-l", args[1]) - assert.Equal(t, "-c", args[2]) - assert.Equal(t, "echo test", args[3]) + assert.Equal(t, "-c", args[1]) + assert.Equal(t, "echo test", args[2]) + + args = server.getShellCommandArgs("/bin/sh", "") + assert.Equal(t, "/bin/sh", args[0]) + assert.Len(t, args, 1) } } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go 
index 3fd578064..f12a75961 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() hasCommand := len(session.Command()) > 0 - switch { - case isPty && hasCommand: - // ssh -t - Pty command execution - s.handleCommand(logger, session, privilegeResult, winCh) - case isPty: - // ssh - Pty interactive session (login) - s.handlePty(logger, session, privilegeResult, ptyReq, winCh) - case hasCommand: - // ssh - non-Pty command execution - s.handleCommand(logger, session, privilegeResult, nil) - default: - // ssh -T (or ssh -N) - no PTY, no command - s.handleNonInteractiveSession(logger, session) - } -} - -// handleNonInteractiveSession handles sessions that have no PTY and no command. -// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N). -func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) { - s.updateSessionType(session, cmdNonInteractive) - - if !s.isPortForwardingEnabled() { - if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil { - logger.Debugf(errWriteSession, err) - } - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - logger.Infof("rejected non-interactive session: port forwarding disabled") - return - } - - <-session.Context().Done() - - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) - } -} - -func (s *Server) updateSessionType(session ssh.Session, sessionType string) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, state := range s.sessions { - if state.session == session { - state.sessionType = sessionType - return - } + if isPty && !hasCommand { + // ssh - PTY interactive session (login) + s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh) + } else { + // ssh , ssh -t , ssh -T - command or shell execution + s.handleExecution(logger, 
session, privilegeResult, ptyReq, winCh) } } diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go index c35e4da0b..4a6cf3d92 100644 --- a/client/ssh/server/session_handlers_js.go +++ b/client/ssh/server/session_handlers_js.go @@ -9,8 +9,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handlePty is not supported on JS/WASM -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { +// handlePtyLogin is not supported on JS/WASM +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { errorMsg := "PTY sessions are not supported on WASM/JS platform\n" if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { logger.Debugf(errWriteSession, err) diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go index f8abd1752..454d3afa3 100644 --- a/client/ssh/server/test.go +++ b/client/ssh/server/test.go @@ -8,19 +8,18 @@ import ( "time" ) +// StartTestServer starts the SSH server and returns the address it's listening on. 
func StartTestServer(t *testing.T, server *Server) string { started := make(chan string, 1) errChan := make(chan error, 1) go func() { - // Use port 0 to let the OS assign a free port addrPort := netip.MustParseAddrPort("127.0.0.1:0") if err := server.Start(context.Background(), addrPort); err != nil { errChan <- err return } - // Get the actual listening address from the server actualAddr := server.Addr() if actualAddr == nil { errChan <- fmt.Errorf("server started but no listener address available") diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index bc1557419..d80b77042 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { // createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. // Returns the command and a cleanup function (no-op on Unix). -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) if err := validateUsername(localUser.Username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) @@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User if err != nil { return nil, nil, fmt.Errorf("parse user credentials: %w", err) } - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) config := ExecutorConfig{ UID: uid, GID: gid, @@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, 
localUser *user.Use shell := getUserShell(localUser.Uid) args := s.getShellCommandArgs(shell, session.RawCommand()) - cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd := s.createShellCommand(session.Context(), shell, args) cmd.Dir = localUser.HomeDir cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go index 5a5f75fa4..260e1301e 100644 --- a/client/ssh/server/userswitching_windows.go +++ b/client/ssh/server/userswitching_windows.go @@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error { // createExecutorCommand creates a command using Windows executor for privilege dropping. // Returns the command and a cleanup function that must be called after starting the process. -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) username, _ := s.parseUsername(localUser.Username) if err := validateUsername(username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) } - return s.createUserSwitchCommand(localUser, session, hasPty) + return s.createUserSwitchCommand(logger, session, localUser) } // createUserSwitchCommand creates a command with Windows user switching. // Returns the command and a cleanup function that must be called after starting the process. 
-func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { +func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) { username, domain := s.parseUsername(localUser.Username) shell := getUserShell(localUser.Uid) @@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi } config := WindowsExecutorConfig{ - Username: username, - Domain: domain, - WorkingDir: localUser.HomeDir, - Shell: shell, - Command: command, - Interactive: interactive || (rawCmd == ""), + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, } - dropper := NewPrivilegeDropper() + dropper := NewPrivilegeDropper(WithLogger(logger)) cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) if err != nil { return nil, nil, err @@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi cleanup := func() { if token != 0 { if err := windows.CloseHandle(windows.Handle(token)); err != nil { - log.Debugf("close primary token: %v", err) + logger.Debugf("close primary token: %v", err) } } } diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go index 0f3659ffe..c08ccfd05 100644 --- a/client/ssh/server/winpty/conpty.go +++ b/client/ssh/server/winpty/conpty.go @@ -56,7 +56,7 @@ var ( ) // ExecutePtyWithUserToken executes a command with ConPty using user token. 
-func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { +func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) commandLine := buildCommandLine(args) @@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig Pty: ptyConfig, User: userConfig, Session: session, - Context: ctx, + Context: session.Context(), } return executeConPtyWithConfig(commandLine, config) diff --git a/client/status/status.go b/client/status/status.go index be28ff67d..f13163a41 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -491,6 +491,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total) + var forwardingRulesString string + if o.NumberOfForwardingRules > 0 { + forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules) + } + goos := runtime.GOOS goarch := runtime.GOARCH goarm := "" @@ -514,7 +519,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS "Lazy connection: %s\n"+ "SSH Server: %s\n"+ "Networks: %s\n"+ - "Forwarding rules: %d\n"+ + "%s"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), o.DaemonVersion, @@ -531,7 +536,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS lazyConnectionEnabledStatus, sshServerStatus, networks, - o.NumberOfForwardingRules, + forwardingRulesString, peersCountString, ) return summary diff --git a/client/status/status_test.go b/client/status/status_test.go index ad158722b..b02d78d64 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -567,7 +567,6 @@ Quantum resistance: false Lazy connection: false SSH Server: Disabled Networks: 10.10.0.0/24 -Forwarding rules: 0 Peers count: 
2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -592,7 +591,6 @@ Quantum resistance: false Lazy connection: false SSH Server: Disabled Networks: 10.10.0.0/24 -Forwarding rules: 0 Peers count: 2/2 Connected ` diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 5d955ed25..0290e17d5 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1033,7 +1033,7 @@ func (s *serviceClient) onTrayReady() { s.mDown.Disable() systray.AddSeparator() - s.mSettings = systray.AddMenuItem("Settings", settingsMenuDescr) + s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr) s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false) s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false) s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false) @@ -1060,7 +1060,7 @@ func (s *serviceClient) onTrayReady() { } s.exitNodeMu.Lock() - s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr) + s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr) s.mExitNode.Disable() s.exitNodeMu.Unlock() @@ -1261,7 +1261,6 @@ func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { if enabled { s.mSettings.Enable() - s.mSettings.SetTooltip(settingsMenuDescr) } else { s.mSettings.Hide() s.mSettings.SetTooltip("Settings are disabled by daemon") diff --git a/client/ui/const.go b/client/ui/const.go index 332282c17..48619be75 100644 --- a/client/ui/const.go +++ b/client/ui/const.go @@ -1,8 +1,6 @@ package main const ( - settingsMenuDescr = "Settings of the application" - profilesMenuDescr = "Manage your profiles" allowSSHMenuDescr = "Allow SSH connections" autoConnectMenuDescr = "Connect automatically when the service starts" quantumResistanceMenuDescr = "Enable post-quantum security via 
Rosenpass" @@ -11,7 +9,7 @@ const ( notificationsMenuDescr = "Enable notifications" advancedSettingsMenuDescr = "Advanced settings of the application" debugBundleMenuDescr = "Create and open debug information bundle" - exitNodeMenuDescr = "Select exit node for routing traffic" + disabledMenuDescr = "" networksMenuDescr = "Open the networks management window" latestVersionMenuDescr = "Download latest version" quitMenuDescr = "Quit the client app" diff --git a/client/ui/debug.go b/client/ui/debug.go index e9bcfde41..29f73a66a 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -406,6 +406,10 @@ func (s *serviceClient) configureServiceForDebug( } time.Sleep(time.Second * 3) + if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { + log.Warnf("failed to start CPU profiling: %v", err) + } + return nil } @@ -428,6 +432,10 @@ func (s *serviceClient) collectDebugData( progress.progressBar.Hide() progress.statusLabel.SetText("Collecting debug data...") + if _, err := conn.StopCPUProfile(s.ctx, &proto.StopCPUProfileRequest{}); err != nil { + log.Warnf("failed to stop CPU profiling: %v", err) + } + return nil } diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 9ffacd926..2216c8aeb 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -63,6 +63,8 @@ func (h *eventHandler) listen(ctx context.Context) { h.handleNetworksClick() case <-h.client.mNotifications.ClickedCh: h.handleNotificationsClick() + case <-systray.TrayOpenedCh: + h.client.updateExitNodes() } } } @@ -99,6 +101,8 @@ func (h *eventHandler) handleConnectClick() { func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + h.client.exitNodeStates = []exitNodeState{} + if h.client.connectCancel != nil { log.Debugf("cancelling ongoing connect operation") h.client.connectCancel() diff --git a/client/ui/network.go b/client/ui/network.go index fb73efd7b..9a5ad7662 100644 --- a/client/ui/network.go +++ 
b/client/ui/network.go @@ -341,7 +341,6 @@ func (s *serviceClient) updateExitNodes() { log.Errorf("get client: %v", err) return } - exitNodes, err := s.getExitNodes(conn) if err != nil { log.Errorf("get exit nodes: %v", err) @@ -390,7 +389,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" { s.mExitNode.Remove() - s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr) + s.mExitNode = systray.AddMenuItem("Exit Node", disabledMenuDescr) } var showDeselectAll bool diff --git a/go.mod b/go.mod index 80999ca8a..dc922b2f8 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( require ( fyne.io/fyne/v2 v2.7.0 - fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 + fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.36.3 @@ -68,7 +68,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 diff --git a/go.sum b/go.sum index e44b8b122..da3eb4a0e 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ= fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= -fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 h1:eA5/u2XRd8OUkoMqEv3IBlFYSruNlXD8bRHDiqm0VNI= -fyne.io/systray 
v1.11.1-0.20250603113521-ca66a66d8b58/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4= +fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk= @@ -408,8 +408,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 
h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index d46737c26..5ae64e9f1 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -856,3 +856,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { c.peersUpdateManager.CloseChannels(ctx, peerIDs) } + +func (c *Controller) TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer) { + c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index b1de7d017..64caac861 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -36,4 +36,6 @@ type Controller interface { DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) + + TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer) } diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 5a98eefa8..4e86d2973 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. 
-// Source: ./interface.go +// Source: management/internals/controllers/network_map/interface.go // // Generated by this command: // -// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +// mockgen -package network_map -destination=management/internals/controllers/network_map/interface_mock.go -source=management/internals/controllers/network_map/interface.go -build_flags=-mod=mod // // Package network_map is a generated GoMock package. @@ -211,6 +211,18 @@ func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0) } +// TrackEphemeralPeer mocks base method. +func (m *MockController) TrackEphemeralPeer(ctx context.Context, arg1 *peer.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TrackEphemeralPeer", ctx, arg1) +} + +// TrackEphemeralPeer indicates an expected call of TrackEphemeralPeer. +func (mr *MockControllerMockRecorder) TrackEphemeralPeer(ctx, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrackEphemeralPeer", reflect.TypeOf((*MockController)(nil).TrackEphemeralPeer), ctx, arg1) +} + // UpdateAccountPeer mocks base method. 
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error { m.ctrl.T.Helper() diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 1ff0243f4..32049d044 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -232,6 +232,9 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String()) if err != nil { s.syncSem.Add(-1) + if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + return status.Errorf(codes.PermissionDenied, "peer is not registered") + } return mapError(ctx, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 11af67358..5e9bb42a2 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -30,6 +30,12 @@ type Manager interface { autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) + CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + AcceptUserInvite(ctx context.Context, token, password string) error + RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) + ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error 
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e9eaa644b..e83eeb90a 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -199,6 +199,11 @@ const ( UserPasswordChanged Activity = 103 + UserInviteLinkCreated Activity = 104 + UserInviteLinkAccepted Activity = 105 + UserInviteLinkRegenerated Activity = 106 + UserInviteLinkDeleted Activity = 107 + AccountDeleted Activity = 99999 ) @@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{ JobCreatedByUser: {"Create Job for peer", "peer.job.create"}, UserPasswordChanged: {"User password changed", "user.password.change"}, + + UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"}, + UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"}, + UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"}, + UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/activity/store/crypt.go b/management/server/activity/store/crypt.go deleted file mode 100644 index ce97347d4..000000000 --- a/management/server/activity/store/crypt.go +++ /dev/null @@ -1,136 +0,0 @@ -package store - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "errors" -) - -var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} - -type FieldEncrypt struct { - block cipher.Block - gcm cipher.AEAD -} - -func GenerateKey() (string, error) { - key := make([]byte, 32) - _, err 
:= rand.Read(key) - if err != nil { - return "", err - } - readableKey := base64.StdEncoding.EncodeToString(key) - return readableKey, nil -} - -func NewFieldEncrypt(key string) (*FieldEncrypt, error) { - binKey, err := base64.StdEncoding.DecodeString(key) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(binKey) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - ec := &FieldEncrypt{ - block: block, - gcm: gcm, - } - - return ec, nil -} - -func (ec *FieldEncrypt) LegacyEncrypt(payload string) string { - plainText := pkcs5Padding([]byte(payload)) - cipherText := make([]byte, len(plainText)) - cbc := cipher.NewCBCEncrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, plainText) - return base64.StdEncoding.EncodeToString(cipherText) -} - -// Encrypt encrypts plaintext using AES-GCM -func (ec *FieldEncrypt) Encrypt(payload string) (string, error) { - plaintext := []byte(payload) - nonceSize := ec.gcm.NonceSize() - - nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead()) - if _, err := rand.Read(nonce); err != nil { - return "", err - } - - ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil) - - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - cbc := cipher.NewCBCDecrypter(ec.block, iv) - cbc.CryptBlocks(cipherText, cipherText) - payload, err := pkcs5UnPadding(cipherText) - if err != nil { - return "", err - } - - return string(payload), nil -} - -// Decrypt decrypts ciphertext using AES-GCM -func (ec *FieldEncrypt) Decrypt(data string) (string, error) { - cipherText, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - - nonceSize := ec.gcm.NonceSize() - if len(cipherText) < nonceSize { - return "", errors.New("cipher text too short") - } - 
- nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:] - plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil) - if err != nil { - return "", err - } - - return string(plainText), nil -} - -func pkcs5Padding(ciphertext []byte) []byte { - padding := aes.BlockSize - len(ciphertext)%aes.BlockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(ciphertext, padText...) -} -func pkcs5UnPadding(src []byte) ([]byte, error) { - srcLen := len(src) - if srcLen == 0 { - return nil, errors.New("input data is empty") - } - - paddingLen := int(src[srcLen-1]) - if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen { - return nil, errors.New("invalid padding size") - } - - // Verify that all padding bytes are the same - for i := 0; i < paddingLen; i++ { - if src[srcLen-1-i] != byte(paddingLen) { - return nil, errors.New("invalid padding") - } - } - - return src[:srcLen-paddingLen], nil -} diff --git a/management/server/activity/store/crypt_test.go b/management/server/activity/store/crypt_test.go deleted file mode 100644 index 700bbcd6b..000000000 --- a/management/server/activity/store/crypt_test.go +++ /dev/null @@ -1,310 +0,0 @@ -package store - -import ( - "bytes" - "testing" -) - -func TestGenerateKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.Decrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestGenerateKeyLegacy(t *testing.T) { - testData 
:= "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted := ee.LegacyEncrypt(testData) - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - decrypted, err := ee.LegacyDecrypt(encrypted) - if err != nil { - t.Fatalf("failed to decrypt data: %s", err) - } - - if decrypted != testData { - t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) - } -} - -func TestCorruptKey(t *testing.T) { - testData := "exampl@netbird.io" - key, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - ee, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - encrypted, err := ee.Encrypt(testData) - if err != nil { - t.Fatalf("failed to encrypt data: %s", err) - } - - if encrypted == "" { - t.Fatalf("invalid encrypted text") - } - - newKey, err := GenerateKey() - if err != nil { - t.Fatalf("failed to generate key: %s", err) - } - - ee, err = NewFieldEncrypt(newKey) - if err != nil { - t.Fatalf("failed to init email encryption: %s", err) - } - - res, _ := ee.Decrypt(encrypted) - if res == testData { - t.Fatalf("incorrect decryption, the result is: %s", res) - } -} - -func TestEncryptDecrypt(t *testing.T) { - // Generate a key for encryption/decryption - key, err := GenerateKey() - if err != nil { - t.Fatalf("Failed to generate key: %v", err) - } - - // Initialize the FieldEncrypt with the generated key - ec, err := NewFieldEncrypt(key) - if err != nil { - t.Fatalf("Failed to create FieldEncrypt: %v", err) - } - - // Test cases - testCases := []struct { - name string - input string - }{ - { - name: "Empty String", - input: "", - }, - { - name: "Short String", - input: "Hello", - }, - { - name: "String with Spaces", - input: "Hello, World!", - }, - { - name: "Long String", - 
input: "The quick brown fox jumps over the lazy dog.", - }, - { - name: "Unicode Characters", - input: "こんにちは世界", - }, - { - name: "Special Characters", - input: "!@#$%^&*()_+-=[]{}|;':\",./<>?", - }, - { - name: "Numeric String", - input: "1234567890", - }, - { - name: "Repeated Characters", - input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - }, - { - name: "Multi-block String", - input: "This is a longer string that will span multiple blocks in the encryption algorithm.", - }, - { - name: "Non-ASCII and ASCII Mix", - input: "Hello 世界 123", - }, - } - - for _, tc := range testCases { - t.Run(tc.name+" - Legacy", func(t *testing.T) { - // Legacy Encryption - encryptedLegacy := ec.LegacyEncrypt(tc.input) - if encryptedLegacy == "" { - t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input) - } - - // Legacy Decryption - decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy) - if err != nil { - t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedLegacy != tc.input { - t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input) - } - }) - - t.Run(tc.name+" - New", func(t *testing.T) { - // New Encryption - encryptedNew, err := ec.Encrypt(tc.input) - if err != nil { - t.Errorf("Encrypt failed for input '%s': %v", tc.input, err) - } - if encryptedNew == "" { - t.Errorf("Encrypt returned empty string for input '%s'", tc.input) - } - - // New Decryption - decryptedNew, err := ec.Decrypt(encryptedNew) - if err != nil { - t.Errorf("Decrypt failed for input '%s': %v", tc.input, err) - } - - // Verify that the decrypted value matches the original input - if decryptedNew != tc.input { - t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input) - } - }) - } -} - -func TestPKCS5UnPadding(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - expectError 
bool - }{ - { - name: "Valid Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), - expected: []byte("Hello, World!"), - }, - { - name: "Empty Input", - input: []byte{}, - expectError: true, - }, - { - name: "Padding Length Zero", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Block Size", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), - expectError: true, - }, - { - name: "Padding Length Exceeds Input Length", - input: []byte{5, 5, 5}, - expectError: true, - }, - { - name: "Invalid Padding Bytes", - input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), - expectError: true, - }, - { - name: "Valid Single Byte Padding", - input: append([]byte("Hello, World!"), byte(1)), - expected: []byte("Hello, World!"), - }, - { - name: "Invalid Mixed Padding Bytes", - input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), - expectError: true, - }, - { - name: "Valid Full Block Padding", - input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("Hello, World!"), - }, - { - name: "Non-Padding Byte at End", - input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), - expectError: true, - }, - { - name: "Valid Padding with Different Text Length", - input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), - expected: []byte("Test"), - }, - { - name: "Padding Length Equal to Input Length", - input: bytes.Repeat([]byte{8}, 8), - expected: []byte{}, - }, - { - name: "Invalid Padding Length Zero (Again)", - input: append([]byte("Test"), byte(0)), - expectError: true, - }, - { - name: "Padding Length Greater Than Input", - input: []byte{10}, - expectError: true, - }, - { - name: "Input Length Not Multiple of Block Size", - input: append([]byte("Invalid Length"), byte(1)), - expected: []byte("Invalid Length"), - }, - { - name: "Valid Padding with Non-ASCII Characters", - input: 
append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), - expected: []byte("こんにちは"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs5UnPadding(tt.input) - if tt.expectError { - if err == nil { - t.Errorf("Expected error but got nil") - } - } else { - if err != nil { - t.Errorf("Did not expect error but got: %v", err) - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("Expected output %v, got %v", tt.expected, result) - } - } - }) - } -} diff --git a/management/server/activity/store/migration.go b/management/server/activity/store/migration.go index af19a34eb..d0f165d5f 100644 --- a/management/server/activity/store/migration.go +++ b/management/server/activity/store/migration.go @@ -10,9 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" + "github.com/netbirdio/netbird/util/crypt" ) -func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { +func migrate(ctx context.Context, crypt *crypt.FieldEncrypt, db *gorm.DB) error { migrations := getMigrations(ctx, crypt) for _, m := range migrations { @@ -26,7 +27,7 @@ func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error { type migrationFunc func(*gorm.DB) error -func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { +func getMigrations(ctx context.Context, crypt *crypt.FieldEncrypt) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "") @@ -45,7 +46,7 @@ func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc { // migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using // legacy CBC encryption with a static IV to the new GCM encryption method. 
-func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error { +func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *crypt.FieldEncrypt) error { model := &activity.DeletedUser{} if !db.Migrator().HasTable(model) { @@ -80,7 +81,7 @@ func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *F return nil } -func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error { +func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *crypt.FieldEncrypt) error { var err error var decryptedEmail, decryptedName string diff --git a/management/server/activity/store/migration_test.go b/management/server/activity/store/migration_test.go index e3261d9fa..5c6f5ade8 100644 --- a/management/server/activity/store/migration_test.go +++ b/management/server/activity/store/migration_test.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -40,10 +41,10 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) { db := setupDatabase(t) - key, err := GenerateKey() + key, err := crypt.GenerateKey() require.NoError(t, err, "Failed to generate key") - crypt, err := NewFieldEncrypt(key) + crypt, err := crypt.NewFieldEncrypt(key) require.NoError(t, err, "Failed to initialize FieldEncrypt") t.Run("empty table, no migration required", func(t *testing.T) { diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index ffecb6b8f..db614d0cd 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" 
"github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util/crypt" ) const ( @@ -45,12 +46,12 @@ type eventWithNames struct { // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { db *gorm.DB - fieldEncrypt *FieldEncrypt + fieldEncrypt *crypt.FieldEncrypt } // NewSqlStore creates a new Store with an event table if not exists. func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) { - crypt, err := NewFieldEncrypt(encryptionKey) + fieldEncrypt, err := crypt.NewFieldEncrypt(encryptionKey) if err != nil { return nil, err @@ -61,7 +62,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return nil, fmt.Errorf("initialize database: %w", err) } - if err = migrate(ctx, crypt, db); err != nil { + if err = migrate(ctx, fieldEncrypt, db); err != nil { return nil, fmt.Errorf("events database migration: %w", err) } @@ -72,7 +73,7 @@ func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*St return &Store{ db: db, - fieldEncrypt: crypt, + fieldEncrypt: fieldEncrypt, }, nil } diff --git a/management/server/activity/store/sql_store_test.go b/management/server/activity/store/sql_store_test.go index 8c0d159df..d723f1623 100644 --- a/management/server/activity/store/sql_store_test.go +++ b/management/server/activity/store/sql_store_test.go @@ -9,11 +9,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/util/crypt" ) func TestNewSqlStore(t *testing.T) { dataDir := t.TempDir() - key, _ := GenerateKey() + key, _ := crypt.GenerateKey() store, err := NewSqlStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 64f914afe..32a97ff44 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -68,6 
+68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks if err := bypass.AddBypassPath("/api/setup"); err != nil { return nil, fmt.Errorf("failed to add bypass path: %w", err) } + // Public invite endpoints (tokens start with nbi_) + if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } + if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } var rateLimitingConfig *middleware.RateLimiterConfig if os.Getenv(rateLimitingEnabledKey) == "true" { @@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) + users.AddInvitesEndpoints(accountManager, router) + users.AddPublicInvitesEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) policies.AddPostureCheckEndpoints(accountManager, LocationManager, router) @@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) + instance.AddVersionEndpoint(instanceManager, router) // Mount embedded IdP handler at /oauth2 path if configured if embeddedIdpEnabled { diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index 889c3133e..5d8baaf8d 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -28,6 +28,15 @@ func AddEndpoints(instanceManager 
nbinstance.Manager, router *mux.Router) { router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS") } +// AddVersionEndpoint registers the authenticated version endpoint. +func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) { + h := &handler{ + instanceManager: instanceManager, + } + + router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS") +} + // getInstanceStatus returns the instance status including whether setup is required. // This endpoint is unauthenticated. func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { @@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) { Email: userData.Email, }) } + +// getVersionInfo returns version information for NetBird components. +// This endpoint requires authentication. +func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) { + versionInfo, err := h.instanceManager.GetVersionInfo(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get version info: %v", err) + util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w) + return + } + + resp := api.InstanceVersionInfo{ + ManagementCurrentVersion: versionInfo.CurrentVersion, + ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable, + } + + if versionInfo.DashboardVersion != "" { + resp.DashboardAvailableVersion = &versionInfo.DashboardVersion + } + + if versionInfo.ManagementVersion != "" { + resp.ManagementAvailableVersion = &versionInfo.ManagementVersion + } + + util.WriteJSONObject(r.Context(), w, resp) +} diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 7a3a2bc88..470079c85 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -25,6 +25,7 @@ type mockInstanceManager 
struct { isSetupRequired bool isSetupRequiredFn func(ctx context.Context) (bool, error) createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error) } func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) { @@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo }, nil } +func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) { + if m.getVersionInfoFn != nil { + return m.getVersionInfoFn(ctx) + } + return &nbinstance.VersionInfo{ + CurrentVersion: "0.34.0", + DashboardVersion: "2.0.0", + ManagementVersion: "0.35.0", + ManagementUpdateAvailable: true, + }, nil +} + var _ nbinstance.Manager = (*mockInstanceManager)(nil) func setupTestRouter(manager nbinstance.Manager) *mux.Router { @@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) } + +func TestGetVersionInfo_Success(t *testing.T) { + manager := &mockInstanceManager{} + router := mux.NewRouter() + AddVersionEndpoint(manager, router) + + req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response api.InstanceVersionInfo + err := json.NewDecoder(rec.Body).Decode(&response) + require.NoError(t, err) + + assert.Equal(t, "0.34.0", response.ManagementCurrentVersion) + assert.NotNil(t, response.DashboardAvailableVersion) + assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion) + assert.NotNil(t, response.ManagementAvailableVersion) + assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion) + assert.True(t, response.ManagementUpdateAvailable) +} + +func TestGetVersionInfo_Error(t *testing.T) { + manager := &mockInstanceManager{ + getVersionInfoFn: func(ctx context.Context) 
(*nbinstance.VersionInfo, error) { + return nil, errors.New("failed to fetch versions") + }, + } + router := mux.NewRouter() + AddVersionEndpoint(manager, router) + + req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} diff --git a/management/server/http/handlers/users/invites_handler.go b/management/server/http/handlers/users/invites_handler.go new file mode 100644 index 000000000..0f0f57c29 --- /dev/null +++ b/management/server/http/handlers/users/invites_handler.go @@ -0,0 +1,263 @@ +package users + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "time" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/account" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks +var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{ + RequestsPerMinute: 10, // 10 attempts per minute per IP + Burst: 5, // Allow burst of 5 requests + CleanupInterval: 10 * time.Minute, + LimiterTTL: 30 * time.Minute, +}) + +// toUserInviteResponse converts a UserInvite to an API response. 
+func toUserInviteResponse(invite *types.UserInvite) api.UserInvite { + autoGroups := invite.UserInfo.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + var inviteLink *string + if invite.InviteToken != "" { + inviteLink = &invite.InviteToken + } + return api.UserInvite{ + Id: invite.UserInfo.ID, + Email: invite.UserInfo.Email, + Name: invite.UserInfo.Name, + Role: invite.UserInfo.Role, + AutoGroups: autoGroups, + ExpiresAt: invite.InviteExpiresAt.UTC(), + CreatedAt: invite.InviteCreatedAt.UTC(), + Expired: time.Now().After(invite.InviteExpiresAt), + InviteToken: inviteLink, + } +} + +// invitesHandler handles user invite operations +type invitesHandler struct { + accountManager account.Manager +} + +// AddInvitesEndpoints registers invite-related endpoints +func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) { + h := &invitesHandler{accountManager: accountManager} + + // Authenticated endpoints (require admin) + router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS") + router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS") +} + +// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting +func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) { + h := &invitesHandler{accountManager: accountManager} + + // Create a subrouter for public invite endpoints with rate limiting middleware + publicRouter := router.PathPrefix("/users/invites").Subrouter() + publicRouter.Use(publicInviteRateLimiter.Middleware) + + // Public endpoints (no auth required, protected by token and rate limited) + publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS") + publicRouter.HandleFunc("/{token}/accept", 
h.acceptInvite).Methods("POST", "OPTIONS") +} + +// listInvites handles GET /api/users/invites +func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := make([]api.UserInvite, 0, len(invites)) + for _, invite := range invites { + resp = append(resp, toUserInviteResponse(invite)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +// createInvite handles POST /api/users/invites +func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.UserInviteCreateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + invite := &types.UserInfo{ + Email: req.Email, + Name: req.Name, + Role: req.Role, + AutoGroups: req.AutoGroups, + } + + expiresIn := 0 + if req.ExpiresIn != nil { + expiresIn = *req.ExpiresIn + } + + result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + result.InviteCreatedAt = time.Now().UTC() + resp := toUserInviteResponse(result) + util.WriteJSONObject(r.Context(), w, &resp) +} + +// getInviteInfo handles GET /api/users/invites/{token} +func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) { + + vars := mux.Vars(r) + token := vars["token"] + if token == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w) + return + } + + info, 
err := h.accountManager.GetUserInviteInfo(r.Context(), token) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + expiresAt := info.ExpiresAt.UTC() + util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{ + Email: info.Email, + Name: info.Name, + ExpiresAt: expiresAt, + Valid: info.Valid, + InvitedBy: info.InvitedBy, + }) +} + +// acceptInvite handles POST /api/users/invites/{token}/accept +func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) { + + vars := mux.Vars(r) + token := vars["token"] + if token == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w) + return + } + + var req api.UserInviteAcceptRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true}) +} + +// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate +func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + inviteID := vars["inviteId"] + if inviteID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w) + return + } + + var req api.UserInviteRegenerateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + // Allow empty body (io.EOF) - expiresIn is optional + if !errors.Is(err, io.EOF) { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, 
w) + return + } + } + + expiresIn := 0 + if req.ExpiresIn != nil { + expiresIn = *req.ExpiresIn + } + + result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + expiresAt := result.InviteExpiresAt.UTC() + util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{ + InviteToken: result.InviteToken, + InviteExpiresAt: expiresAt, + }) +} + +// deleteInvite handles DELETE /api/users/invites/{inviteId} +func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + inviteID := vars["inviteId"] + if inviteID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w) + return + } + + err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go new file mode 100644 index 000000000..80826b9d4 --- /dev/null +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -0,0 +1,642 @@ +package users + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + 
"github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "test-account-id" + testUserID = "test-user-id" + testInviteID = "test-invite-id" + testInviteToken = "nbi_testtoken123456789012345678" + testEmail = "invite@example.com" + testName = "Test User" +) + +func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler { + return &invitesHandler{ + accountManager: am, + } +} + +func TestListInvites(t *testing.T) { + now := time.Now().UTC() + testInvites := []*types.UserInvite{ + { + UserInfo: &types.UserInfo{ + ID: "invite-1", + Email: "user1@example.com", + Name: "User One", + Role: "user", + AutoGroups: []string{"group-1"}, + }, + InviteExpiresAt: now.Add(24 * time.Hour), + InviteCreatedAt: now, + }, + { + UserInfo: &types.UserInfo{ + ID: "invite-2", + Email: "user2@example.com", + Name: "User Two", + Role: "admin", + AutoGroups: nil, + }, + InviteExpiresAt: now.Add(-1 * time.Hour), // Expired + InviteCreatedAt: now.Add(-48 * time.Hour), + }, + } + + tt := []struct { + name string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + expectedCount int + }{ + { + name: "successful list", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return testInvites, nil + }, + expectedCount: 2, + }, + { + name: "empty list", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return []*types.UserInvite{}, nil + }, + expectedCount: 0, + }, + { + name: "permission denied", + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + expectedCount: 0, + }, + } + + for _, tc := range tt { + 
t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + ListUserInvitesFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + rr := httptest.NewRecorder() + handler.listInvites(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp []api.UserInvite + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Len(t, resp, tc.expectedCount) + } + }) + } +} + +func TestCreateInvite(t *testing.T) { + now := time.Now().UTC() + expiresAt := now.Add(72 * time.Hour) + + tt := []struct { + name string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + }{ + { + name: "successful create", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: testInviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + Status: string(types.UserStatusInvited), + }, + InviteToken: testInviteToken, + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "successful create with custom expiration", + requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 3600, expiresIn) + return 
&types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: testInviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: []string{}, + Status: string(types.UserStatusInvited), + }, + InviteToken: testInviteToken, + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "user already exists", + requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusConflict, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists") + }, + }, + { + name: "invite already exists", + requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusConflict, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email") + }, + }, + { + name: "permission denied", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + }, + { + name: "embedded IDP not enabled", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "invalid JSON", + 
requestBody: `{invalid json}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + CreateUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody)) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + rr := httptest.NewRecorder() + handler.createInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInvite + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, testInviteID, resp.Id) + assert.NotNil(t, resp.InviteToken) + assert.NotEmpty(t, *resp.InviteToken) + } + }) + } +} + +func TestGetInviteInfo(t *testing.T) { + now := time.Now().UTC() + + tt := []struct { + name string + token string + expectedStatus int + mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error) + }{ + { + name: "successful get valid invite", + token: testInviteToken, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return &types.UserInviteInfo{ + Email: testEmail, + Name: testName, + ExpiresAt: now.Add(24 * time.Hour), + Valid: true, + InvitedBy: "Admin User", + }, nil + }, + }, + { + name: "successful get expired invite", + token: testInviteToken, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return &types.UserInviteInfo{ + Email: testEmail, + Name: testName, + ExpiresAt: now.Add(-24 * time.Hour), + Valid: false, + InvitedBy: "Admin User", + }, nil + }, + }, + { + name: "invite not found", + token: "nbi_invalidtoken1234567890123456", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, 
token string) (*types.UserInviteInfo, error) { + return nil, status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "invalid token format", + token: "invalid", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return nil, status.Errorf(status.InvalidArgument, "invalid invite token") + }, + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + GetUserInviteInfoFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil) + if tc.token != "" { + req = mux.SetURLVars(req, map[string]string{"token": tc.token}) + } + + rr := httptest.NewRecorder() + handler.getInviteInfo(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteInfo + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, testEmail, resp.Email) + assert.Equal(t, testName, resp.Name) + } + }) + } +} + +func TestAcceptInvite(t *testing.T) { + tt := []struct { + name string + token string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, token, password string) error + }{ + { + name: "successful accept", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token, password string) error { + return nil + }, + }, + { + name: "invite not found", + token: "nbi_invalidtoken1234567890123456", + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: 
"invite expired", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "invite has expired") + }, + }, + { + name: "embedded IDP not enabled", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "missing token", + token: "", + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + { + name: "invalid JSON", + token: testInviteToken, + requestBody: `{invalid}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + { + name: "password too short", + token: testInviteToken, + requestBody: `{"password":"Short1!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long") + }, + }, + { + name: "password missing digit", + token: testInviteToken, + requestBody: `{"password":"NoDigitPass!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one digit") + }, + }, + { + name: "password missing uppercase", + token: testInviteToken, + requestBody: `{"password":"nouppercase1!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter") + }, + }, + { + name: "password missing special 
character", + token: testInviteToken, + requestBody: `{"password":"NoSpecial123"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one special character") + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + AcceptUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody)) + if tc.token != "" { + req = mux.SetURLVars(req, map[string]string{"token": tc.token}) + } + + rr := httptest.NewRecorder() + handler.acceptInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteAcceptResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.True(t, resp.Success) + } + }) + } +} + +func TestRegenerateInvite(t *testing.T) { + now := time.Now().UTC() + expiresAt := now.Add(72 * time.Hour) + + tt := []struct { + name string + inviteID string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + }{ + { + name: "successful regenerate with empty body", + inviteID: testInviteID, + requestBody: "", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 0, expiresIn) + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: testEmail, + }, + InviteToken: "nbi_newtoken12345678901234567890", + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "successful regenerate with custom expiration", + inviteID: testInviteID, + requestBody: `{"expires_in":7200}`, + 
expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 7200, expiresIn) + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: testEmail, + }, + InviteToken: "nbi_newtoken12345678901234567890", + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "invite not found", + inviteID: "non-existent-invite", + requestBody: "", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "permission denied", + inviteID: testInviteID, + requestBody: "", + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + }, + { + name: "missing invite ID", + inviteID: "", + requestBody: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + { + name: "invalid JSON should return error", + inviteID: testInviteID, + requestBody: `{invalid json}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + RegenerateUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + var body io.Reader + if tc.requestBody != "" { + body = bytes.NewBufferString(tc.requestBody) + } + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + if tc.inviteID != "" { + req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID}) + } + + rr := httptest.NewRecorder() + 
handler.regenerateInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteRegenerateResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.NotEmpty(t, resp.InviteToken) + } + }) + } +} + +func TestDeleteInvite(t *testing.T) { + tt := []struct { + name string + inviteID string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error + }{ + { + name: "successful delete", + inviteID: testInviteID, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return nil + }, + }, + { + name: "invite not found", + inviteID: "non-existent-invite", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "permission denied", + inviteID: testInviteID, + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.NewPermissionDeniedError() + }, + }, + { + name: "embedded IDP not enabled", + inviteID: testInviteID, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "missing invite ID", + inviteID: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + DeleteUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil) + req = 
nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + if tc.inviteID != "" { + req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID}) + } + + rr := httptest.NewRecorder() + handler.deleteInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + }) + } +} diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index a6266d4f3..936b34319 100644 --- a/management/server/http/middleware/rate_limiter.go +++ b/management/server/http/middleware/rate_limiter.go @@ -2,10 +2,14 @@ package middleware import ( "context" + "net" + "net/http" "sync" "time" "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/shared/management/http/util" ) // RateLimiterConfig holds configuration for the API rate limiter @@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) { defer rl.mu.Unlock() delete(rl.limiters, key) } + +// Middleware returns an HTTP middleware that rate limits requests by client IP. +// Returns 429 Too Many Requests if the rate limit is exceeded. +func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientIP := getClientIP(r) + if !rl.Allow(clientIP) { + util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) + return + } + next.ServeHTTP(w, r) + }) +} + +// getClientIP extracts the client IP address from the request. 
+func getClientIP(r *http.Request) string { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go new file mode 100644 index 000000000..68f804e57 --- /dev/null +++ b/management/server/http/middleware/rate_limiter_test.go @@ -0,0 +1,158 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAPIRateLimiter_Allow(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, // 1 per second + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // First two requests should be allowed (burst) + assert.True(t, rl.Allow("test-key")) + assert.True(t, rl.Allow("test-key")) + + // Third request should be denied (exceeded burst) + assert.False(t, rl.Allow("test-key")) + + // Different key should be allowed + assert.True(t, rl.Allow("different-key")) +} + +func TestAPIRateLimiter_Middleware(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, // 1 per second + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // Create a simple handler that returns 200 OK + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Wrap with rate limiter middleware + handler := rl.Middleware(nextHandler) + + // First two requests should pass (burst) + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1) + } + + // Third request should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", nil) + 
req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code) +} + +func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := rl.Middleware(nextHandler) + + // Request from first IP + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "192.168.1.1:12345" + rr1 := httptest.NewRecorder() + handler.ServeHTTP(rr1, req1) + assert.Equal(t, http.StatusOK, rr1.Code) + + // Second request from first IP should be rate limited + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "192.168.1.1:12345" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + assert.Equal(t, http.StatusTooManyRequests, rr2.Code) + + // Request from different IP should be allowed + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "192.168.1.2:12345" + rr3 := httptest.NewRecorder() + handler.ServeHTTP(rr3, req3) + assert.Equal(t, http.StatusOK, rr3.Code) +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + expected string + }{ + { + name: "remote addr with port", + remoteAddr: "192.168.1.1:12345", + expected: "192.168.1.1", + }, + { + name: "remote addr without port", + remoteAddr: "192.168.1.1", + expected: "192.168.1.1", + }, + { + name: "IPv6 with port", + remoteAddr: "[::1]:12345", + expected: "::1", + }, + { + name: "IPv6 without port", + remoteAddr: "::1", + expected: "::1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = tc.remoteAddr + assert.Equal(t, 
tc.expected, getClientIP(req)) + }) + } +} + +func TestAPIRateLimiter_Reset(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // Use up the burst + assert.True(t, rl.Allow("test-key")) + assert.False(t, rl.Allow("test-key")) + + // Reset the limiter + rl.Reset("test-key") + + // Should be allowed again + assert.True(t, rl.Allow("test-key")) +} diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 79859525b..db7a91fa3 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -20,7 +20,7 @@ const ( staticClientCLI = "netbird-cli" defaultCLIRedirectURL1 = "http://localhost:53000/" defaultCLIRedirectURL2 = "http://localhost:54000/" - defaultScopes = "openid profile email" + defaultScopes = "openid profile email groups" defaultUserIDClaim = "sub" ) diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 6f50e3ff7..6a0509ebd 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -2,18 +2,54 @@ package instance import ( "context" + "encoding/json" "errors" "fmt" + "io" + "net/http" "net/mail" + "strings" "sync" + "time" + goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/version" ) +const ( + // Version endpoints + managementVersionURL = "https://pkgs.netbird.io/releases/latest/version" + dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest" + + // Cache TTL for version information + versionCacheTTL = 60 * time.Minute + + // HTTP client timeout + httpTimeout = 5 * time.Second +) + +// VersionInfo contains version information for NetBird 
components +type VersionInfo struct { + // CurrentVersion is the running management server version + CurrentVersion string + // DashboardVersion is the latest available dashboard version from GitHub + DashboardVersion string + // ManagementVersion is the latest available management version from pkgs.netbird.io + ManagementVersion string + // ManagementUpdateAvailable indicates if a newer management version is available + ManagementUpdateAvailable bool +} + +// githubRelease represents a GitHub release response +type githubRelease struct { + TagName string `json:"tag_name"` +} + // Manager handles instance-level operations like initial setup. type Manager interface { // IsSetupRequired checks if instance setup is required. @@ -23,6 +59,9 @@ type Manager interface { // CreateOwnerUser creates the initial owner user in the embedded IDP. // This should only be called when IsSetupRequired returns true. CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) + + // GetVersionInfo returns version information for NetBird components. + GetVersionInfo(ctx context.Context) (*VersionInfo, error) } // DefaultManager is the default implementation of Manager. @@ -32,6 +71,12 @@ type DefaultManager struct { setupRequired bool setupMu sync.RWMutex + + // Version caching + httpClient *http.Client + versionMu sync.RWMutex + cachedVersions *VersionInfo + lastVersionFetch time.Time } // NewManager creates a new instance manager. @@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) store: store, embeddedIdpManager: embeddedIdp, setupRequired: false, + httpClient: &http.Client{ + Timeout: httpTimeout, + }, } if embeddedIdp != nil { @@ -134,3 +182,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error { } return nil } + +// GetVersionInfo returns version information for NetBird components.
+func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) { + m.versionMu.RLock() + if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL { + cached := *m.cachedVersions + m.versionMu.RUnlock() + return &cached, nil + } + m.versionMu.RUnlock() + + return m.fetchVersionInfo(ctx) +} + +func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) { + m.versionMu.Lock() + // Double-check after acquiring write lock + if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL { + cached := *m.cachedVersions + m.versionMu.Unlock() + return &cached, nil + } + m.versionMu.Unlock() + + info := &VersionInfo{ + CurrentVersion: version.NetbirdVersion(), + } + + // Fetch management version from pkgs.netbird.io (plain text) + mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL) + if err != nil { + log.WithContext(ctx).Warnf("failed to fetch management version: %v", err) + } else { + info.ManagementVersion = mgmtVersion + info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion) + } + + // Fetch dashboard version from GitHub + dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL) + if err != nil { + log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err) + } else { + info.DashboardVersion = dashVersion + } + + // Update cache + m.versionMu.Lock() + m.cachedVersions = info + m.lastVersionFetch = time.Now() + m.versionMu.Unlock() + + return info, nil +} + +// isNewerVersion returns true if latestVersion is greater than currentVersion +func isNewerVersion(currentVersion, latestVersion string) bool { + current, err := goversion.NewVersion(currentVersion) + if err != nil { + return false + } + + latest, err := goversion.NewVersion(latestVersion) + if err != nil { + return false + } + + return latest.GreaterThan(current) +} + +func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) 
(string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + + req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion()) + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 100)) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + return strings.TrimSpace(string(body)), nil +} + +func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion()) + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + + // Remove 'v' prefix if present + tag := release.TagName + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + + return tag, nil +} diff --git a/management/server/instance/version_test.go b/management/server/instance/version_test.go new file mode 100644 index 000000000..35ba66db8 --- /dev/null +++ b/management/server/instance/version_test.go @@ -0,0 +1,285 @@ +package instance + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + 
"testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRoundTripper implements http.RoundTripper for testing +type mockRoundTripper struct { + callCount atomic.Int32 + managementVersion string + dashboardVersion string +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.callCount.Add(1) + + var body string + if strings.Contains(req.URL.String(), "pkgs.netbird.io") { + // Plain text response for management version + body = m.managementVersion + } else if strings.Contains(req.URL.String(), "github.com") { + // JSON response for dashboard version + jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion}) + body = string(jsonResp) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + }, nil +} + +func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) { + mockTransport := &mockRoundTripper{ + managementVersion: "0.65.0", + dashboardVersion: "2.10.0", + } + + m := &DefaultManager{ + httpClient: &http.Client{Transport: mockTransport}, + } + + ctx := context.Background() + + info, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + + // CurrentVersion should always be set + assert.NotEmpty(t, info.CurrentVersion) + assert.Equal(t, "0.65.0", info.ManagementVersion) + assert.Equal(t, "2.10.0", info.DashboardVersion) + assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard +} + +func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) { + mockTransport := &mockRoundTripper{ + managementVersion: "0.65.0", + dashboardVersion: "2.10.0", + } + + m := &DefaultManager{ + httpClient: &http.Client{Transport: mockTransport}, + } + + ctx := context.Background() + + // First call + info1, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + assert.NotEmpty(t, info1.CurrentVersion) + assert.Equal(t, "0.65.0", 
info1.ManagementVersion) + + initialCallCount := mockTransport.callCount.Load() + + // Second call should use cache (no additional HTTP calls) + info2, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion) + assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion) + assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion) + + // Verify no additional HTTP calls were made (cache was used) + assert.Equal(t, initialCallCount, mockTransport.callCount.Load()) +} + +func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) { + tests := []struct { + name string + tagName string + expected string + shouldError bool + }{ + { + name: "tag with v prefix", + tagName: "v1.2.3", + expected: "1.2.3", + }, + { + name: "tag without v prefix", + tagName: "1.2.3", + expected: "1.2.3", + }, + { + name: "tag with prerelease", + tagName: "v2.0.0-beta.1", + expected: "2.0.0-beta.1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName}) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + version, err := m.fetchGitHubRelease(context.Background(), server.URL) + + if tc.shouldError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expected, version) + } + }) + } +} + +func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + }{ + { + name: "not found", + statusCode: http.StatusNotFound, + body: `{"message": "Not Found"}`, + }, + { + name: "rate limited", + statusCode: http.StatusForbidden, + body: `{"message": "API rate limit exceeded"}`, + }, + { + name: "server error", + statusCode: 
http.StatusInternalServerError, + body: `{"message": "Internal Server Error"}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + _, _ = w.Write([]byte(tc.body)) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + _, err := m.fetchGitHubRelease(context.Background(), server.URL) + assert.Error(t, err) + }) + } +} + +func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{invalid json}`)) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + _, err := m.fetchGitHubRelease(context.Background(), server.URL) + assert.Error(t, err) +} + +func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"}) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := m.fetchGitHubRelease(ctx, server.URL) + assert.Error(t, err) +} + +func TestIsNewerVersion(t *testing.T) { + tests := []struct { + name string + currentVersion string + latestVersion string + expected bool + }{ + { + name: "latest is newer - minor version", + currentVersion: "0.64.1", + latestVersion: "0.65.0", + expected: true, + }, + { + name: "latest is newer - patch version", + currentVersion: "0.64.1", + latestVersion: "0.64.2", + 
expected: true, + }, + { + name: "latest is newer - major version", + currentVersion: "0.64.1", + latestVersion: "1.0.0", + expected: true, + }, + { + name: "versions are equal", + currentVersion: "0.64.1", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "current is newer - minor version", + currentVersion: "0.65.0", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "current is newer - patch version", + currentVersion: "0.64.2", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "development version", + currentVersion: "development", + latestVersion: "0.65.0", + expected: false, + }, + { + name: "invalid latest version", + currentVersion: "0.64.1", + latestVersion: "invalid", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isNewerVersion(tc.currentVersion, tc.latestVersion) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 75e971498..026989898 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -139,6 +139,12 @@ type MockAccountManager struct { CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) + CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + AcceptUserInviteFunc func(ctx context.Context, token, password string) error + RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, 
error) + ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error } func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error { @@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } +func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + if am.CreateUserInviteFunc != nil { + return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented") +} + +func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error { + if am.AcceptUserInviteFunc != nil { + return am.AcceptUserInviteFunc(ctx, token, password) + } + return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented") +} + +func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + if am.RegenerateUserInviteFunc != nil { + return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented") +} + +func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) { + if am.GetUserInviteInfoFunc != nil { + return am.GetUserInviteInfoFunc(ctx, token) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented") +} + +func (am *MockAccountManager) ListUserInvites(ctx 
context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + if am.ListUserInvitesFunc != nil { + return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented") +} + +func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + if am.DeleteUserInviteFunc != nil { + return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID) + } + return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented") +} + func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) diff --git a/management/server/peer.go b/management/server/peer.go index d6eb2aecd..80c74e209 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -728,6 +728,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } + if temporary { + // track ephemeral peers so they can be cleaned up if the peer never syncs and is never marked as connected + am.networkMapController.TrackEphemeralPeer(ctx, newPeer) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0eb687dbb..7f48f510e 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -815,6 +815,130 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre return &user, nil } +// SaveUserInvite saves a user invite to the database +func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error { + inviteCopy := invite.Copy() + if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt invite: %w", err) + } + + result := s.db.Save(inviteCopy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save user invite to store") + } + return nil +} + +// GetUserInviteByID retrieves a user invite by its ID and account ID +func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invite types.UserInviteRecord + result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "user invite not found") + } + log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invite from store") + } + + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + + return &invite, nil +} + +// GetUserInviteByHashedToken 
retrieves a user invite by its hashed token +func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invite types.UserInviteRecord + result := tx.Take(&invite, "hashed_token = ?", hashedToken) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "user invite not found") + } + log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invite from store") + } + + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + + return &invite, nil +} + +// GetUserInviteByEmail retrieves a user invite by account ID and email. +// Since email is encrypted with random IVs, we fetch all invites for the account +// and compare emails in memory after decryption. 
+func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invites []*types.UserInviteRecord + result := tx.Find(&invites, "account_id = ?", accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invites from store") + } + + for _, invite := range invites { + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + if strings.EqualFold(invite.Email, email) { + return invite, nil + } + } + + return nil, status.Errorf(status.NotFound, "user invite not found for email") +} + +// GetAccountUserInvites retrieves all user invites for an account +func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invites []*types.UserInviteRecord + result := tx.Find(&invites, "account_id = ?", accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invites from store") + } + + for _, invite := range invites { + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + } + + return invites, nil +} + +// DeleteUserInvite deletes a user invite by its ID +func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error { + result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID) + if result.Error != nil { + 
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete user invite from store") + } + return nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { tx := s.db if lockStrength != LockingStrengthNone { @@ -4269,6 +4393,9 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS Take(&userID, GetKeyQueryCondition(s), peerKey) if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "peer not found: index lookup failed") + } return "", status.Errorf(status.Internal, "failed to get user ID by peer key") } diff --git a/management/server/store/sql_store_user_invite_test.go b/management/server/store/sql_store_user_invite_test.go new file mode 100644 index 000000000..fb6934a2e --- /dev/null +++ b/management/server/store/sql_store_user_invite_test.go @@ -0,0 +1,520 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/types" +) + +func TestSqlStore_SaveUserInvite(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-1", + AccountID: "account-1", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group-1", "group-2"}, + HashedToken: "hashed-token-123", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Verify the invite was saved + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) 
// TestSqlStore_SaveUserInvite_Update verifies that SaveUserInvite acts as an
// upsert: saving a record with an existing ID overwrites the stored fields.
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-update",
			AccountID:   "account-1",
			Email:       "test@example.com",
			Name:        "Test User",
			Role:        "user",
			AutoGroups:  []string{"group-1"},
			HashedToken: "hashed-token-123",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Update the invite with a new token
		invite.HashedToken = "new-hashed-token"
		invite.ExpiresAt = time.Now().Add(24 * time.Hour)

		err = store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Verify the update
		retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)
		assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
	})
}

// TestSqlStore_GetUserInviteByID covers the success path plus two failure
// paths: a mismatched account ID and a non-existent invite ID.
func TestSqlStore_GetUserInviteByID(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-get-by-id",
			AccountID:   "account-1",
			Email:       "getbyid@example.com",
			Name:        "Get By ID User",
			Role:        "admin",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-get",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Get by ID - success
		retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)
		assert.Equal(t, invite.ID, retrieved.ID)
		assert.Equal(t, invite.Email, retrieved.Email)

		// Get by ID - wrong account
		_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
		assert.Error(t, err)

		// Get by ID - not found
		_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
		assert.Error(t, err)
	})
}

// TestSqlStore_GetUserInviteByHashedToken checks lookup by the stored token
// hash, including the not-found error path.
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-get-by-token",
			AccountID:   "account-1",
			Email:       "getbytoken@example.com",
			Name:        "Get By Token User",
			Role:        "user",
			AutoGroups:  []string{"group-1"},
			HashedToken: "unique-hashed-token-456",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Get by hashed token - success
		retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
		require.NoError(t, err)
		assert.Equal(t, invite.ID, retrieved.ID)
		assert.Equal(t, invite.Email, retrieved.Email)

		// Get by hashed token - not found
		_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
		assert.Error(t, err)
	})
}

// TestSqlStore_GetUserInviteByEmail covers success, case-insensitive match
// (emails are compared with EqualFold after decryption), account scoping,
// and the not-found path.
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-get-by-email",
			AccountID:   "account-email-test",
			Email:       "unique-email@example.com",
			Name:        "Get By Email User",
			Role:        "user",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-email",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Get by email - success
		retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
		require.NoError(t, err)
		assert.Equal(t, invite.ID, retrieved.ID)

		// Get by email - case insensitive
		retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
		require.NoError(t, err)
		assert.Equal(t, invite.ID, retrieved.ID)

		// Get by email - wrong account
		_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
		assert.Error(t, err)

		// Get by email - not found
		_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
		assert.Error(t, err)
	})
}

// TestSqlStore_GetAccountUserInvites verifies that listing is scoped to one
// account and that an account without invites yields an empty result.
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		accountID := "account-list-invites"

		invites := []*types.UserInviteRecord{
			{
				ID:          "invite-list-1",
				AccountID:   accountID,
				Email:       "user1@example.com",
				Name:        "User One",
				Role:        "user",
				AutoGroups:  []string{"group-1"},
				HashedToken: "hashed-token-list-1",
				ExpiresAt:   time.Now().Add(72 * time.Hour),
				CreatedAt:   time.Now(),
				CreatedBy:   "admin-user",
			},
			{
				ID:          "invite-list-2",
				AccountID:   accountID,
				Email:       "user2@example.com",
				Name:        "User Two",
				Role:        "admin",
				AutoGroups:  []string{"group-2"},
				HashedToken: "hashed-token-list-2",
				ExpiresAt:   time.Now().Add(24 * time.Hour),
				CreatedAt:   time.Now(),
				CreatedBy:   "admin-user",
			},
			{
				ID:          "invite-list-3",
				AccountID:   "different-account",
				Email:       "user3@example.com",
				Name:        "User Three",
				Role:        "user",
				AutoGroups:  []string{},
				HashedToken: "hashed-token-list-3",
				ExpiresAt:   time.Now().Add(72 * time.Hour),
				CreatedAt:   time.Now(),
				CreatedBy:   "admin-user",
			},
		}

		for _, invite := range invites {
			err := store.SaveUserInvite(ctx, invite)
			require.NoError(t, err)
		}

		// Get all invites for the account
		retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
		require.NoError(t, err)
		assert.Len(t, retrieved, 2)

		// Verify the invites belong to the correct account
		for _, invite := range retrieved {
			assert.Equal(t, accountID, invite.AccountID)
		}

		// Get invites for account with no invites
		retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
		require.NoError(t, err)
		assert.Len(t, retrieved, 0)
	})
}

// TestSqlStore_DeleteUserInvite checks that a deleted invite can no longer
// be fetched by ID.
func TestSqlStore_DeleteUserInvite(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-delete",
			AccountID:   "account-delete-test",
			Email:       "delete@example.com",
			Name:        "Delete User",
			Role:        "user",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-delete",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Verify invite exists
		_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)

		// Delete the invite
		err = store.DeleteUserInvite(ctx, invite.ID)
		require.NoError(t, err)

		// Verify invite is deleted
		_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		assert.Error(t, err)
	})
}

// TestSqlStore_UserInvite_EncryptedFields verifies the encrypt-on-save /
// decrypt-on-read round trip for the sensitive Email and Name fields.
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-encrypted",
			AccountID:   "account-encrypted",
			Email:       "sensitive-email@example.com",
			Name:        "Sensitive Name",
			Role:        "user",
			AutoGroups:  []string{"group-1"},
			HashedToken: "hashed-token-encrypted",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Retrieve and verify decryption works
		retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)
		assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
		assert.Equal(t, "Sensitive Name", retrieved.Name)
	})
}

// TestSqlStore_DeleteUserInvite_NonExistent pins the contract that deleting
// a missing invite is a no-op, not an error.
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		// Deleting a non-existent invite should not return an error
		err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
		require.NoError(t, err)
	})
}

// TestSqlStore_UserInvite_SameEmailDifferentAccounts verifies that email
// lookups are account-scoped even when two accounts invite the same address.
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		email := "shared-email@example.com"

		// Create invite in first account
		invite1 := &types.UserInviteRecord{
			ID:          "invite-account1",
			AccountID:   "account-1",
			Email:       email,
			Name:        "User Account 1",
			Role:        "user",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-account1",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-1",
		}

		// Create invite in second account with same email
		invite2 := &types.UserInviteRecord{
			ID:          "invite-account2",
			AccountID:   "account-2",
			Email:       email,
			Name:        "User Account 2",
			Role:        "admin",
			AutoGroups:  []string{"group-1"},
			HashedToken: "hashed-token-account2",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-2",
		}

		err := store.SaveUserInvite(ctx, invite1)
		require.NoError(t, err)

		err = store.SaveUserInvite(ctx, invite2)
		require.NoError(t, err)

		// Verify each account gets the correct invite by email
		retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
		require.NoError(t, err)
		assert.Equal(t, "invite-account1", retrieved1.ID)
		assert.Equal(t, "User Account 1", retrieved1.Name)

		retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
		require.NoError(t, err)
		assert.Equal(t, "invite-account2", retrieved2.ID)
		assert.Equal(t, "User Account 2", retrieved2.Name)
	})
}

// TestSqlStore_UserInvite_LockingStrength smoke-tests every read method
// under each supported row-locking mode.
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		invite := &types.UserInviteRecord{
			ID:          "invite-locking",
			AccountID:   "account-locking",
			Email:       "locking@example.com",
			Name:        "Locking Test User",
			Role:        "user",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-locking",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		// Test with different locking strengths
		lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}

		for _, strength := range lockStrengths {
			retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
			require.NoError(t, err)
			assert.Equal(t, invite.ID, retrieved.ID)

			retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
			require.NoError(t, err)
			assert.Equal(t, invite.ID, retrieved.ID)

			retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
			require.NoError(t, err)
			assert.Equal(t, invite.ID, retrieved.ID)

			invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
			require.NoError(t, err)
			assert.Len(t, invites, 1)
		}
	})
}

// TestSqlStore_UserInvite_EmptyAutoGroups verifies nil AutoGroups survive a
// round trip as an empty collection (nil or empty slice both accepted).
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		// Test with nil AutoGroups
		invite := &types.UserInviteRecord{
			ID:          "invite-nil-autogroups",
			AccountID:   "account-autogroups",
			Email:       "nilgroups@example.com",
			Name:        "Nil Groups User",
			Role:        "user",
			AutoGroups:  nil,
			HashedToken: "hashed-token-nil",
			ExpiresAt:   time.Now().Add(72 * time.Hour),
			CreatedAt:   time.Now(),
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)
		// Should return empty slice or nil, both are acceptable
		assert.Empty(t, retrieved.AutoGroups)
	})
}

// TestSqlStore_UserInvite_TimestampPrecision verifies CreatedAt/ExpiresAt
// survive a round trip to within one second across all engines.
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
	runTestForAllEngines(t, "", func(t *testing.T, store Store) {
		if store == nil {
			t.Skip("store is nil")
		}
		ctx := context.Background()

		now := time.Now().UTC().Truncate(time.Millisecond)
		expiresAt := now.Add(72 * time.Hour)

		invite := &types.UserInviteRecord{
			ID:          "invite-timestamp",
			AccountID:   "account-timestamp",
			Email:       "timestamp@example.com",
			Name:        "Timestamp User",
			Role:        "user",
			AutoGroups:  []string{},
			HashedToken: "hashed-token-timestamp",
			ExpiresAt:   expiresAt,
			CreatedAt:   now,
			CreatedBy:   "admin-user",
		}

		err := store.SaveUserInvite(ctx, invite)
		require.NoError(t, err)

		retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
		require.NoError(t, err)

		// Verify timestamps are preserved (within reasonable precision)
		assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
		assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
	})
}
tokens + InviteTokenPrefix = "nbi_" + // InviteTokenSecretLength is the length of the random secret part + InviteTokenSecretLength = 30 + // InviteTokenChecksumLength is the length of the encoded checksum + InviteTokenChecksumLength = 6 + // InviteTokenLength is the total length of the token (4 + 30 + 6 = 40) + InviteTokenLength = 40 + // DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours) + DefaultInviteExpirationSeconds = 259200 + // MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour) + MinInviteExpirationSeconds = 3600 +) + +// UserInviteRecord represents an invitation for a user to set up their account (database model) +type UserInviteRecord struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index;not null"` + Email string `gorm:"index;not null"` + Name string `gorm:"not null"` + Role string `gorm:"not null"` + AutoGroups []string `gorm:"serializer:json"` + HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded) + ExpiresAt time.Time `gorm:"not null"` + CreatedAt time.Time `gorm:"not null"` + CreatedBy string `gorm:"not null"` +} + +// TableName returns the table name for GORM +func (UserInviteRecord) TableName() string { + return "user_invites" +} + +// GenerateInviteToken creates a new invite token with the format: nbi_ +// Returns the hashed token (for storage) and the plain token (to give to the user) +func GenerateInviteToken() (hashedToken string, plainToken string, err error) { + secret, err := b.Random(InviteTokenSecretLength) + if err != nil { + return "", "", fmt.Errorf("failed to generate random secret: %w", err) + } + + checksum := crc32.ChecksumIEEE([]byte(secret)) + encodedChecksum := base62.Encode(checksum) + // Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode) + paddedChecksum := encodedChecksum + if len(paddedChecksum) < InviteTokenChecksumLength { + paddedChecksum = 
strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum + } + + plainToken = InviteTokenPrefix + secret + paddedChecksum + hash := sha256.Sum256([]byte(plainToken)) + hashedToken = b64.StdEncoding.EncodeToString(hash[:]) + + return hashedToken, plainToken, nil +} + +// HashInviteToken creates a SHA-256 hash of the token (base64 encoded) +func HashInviteToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return b64.StdEncoding.EncodeToString(hash[:]) +} + +// ValidateInviteToken validates the token format and checksum. +// Returns an error if the token is invalid. +func ValidateInviteToken(token string) error { + if len(token) != InviteTokenLength { + return fmt.Errorf("invalid token length") + } + + prefix := token[:len(InviteTokenPrefix)] + if prefix != InviteTokenPrefix { + return fmt.Errorf("invalid token prefix") + } + + secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength] + encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:] + + verificationChecksum, err := base62.Decode(encodedChecksum) + if err != nil { + return fmt.Errorf("checksum decoding failed: %w", err) + } + + secretChecksum := crc32.ChecksumIEEE([]byte(secret)) + if secretChecksum != verificationChecksum { + return fmt.Errorf("checksum does not match") + } + + return nil +} + +// IsExpired checks if the invite has expired +func (i *UserInviteRecord) IsExpired() bool { + return time.Now().After(i.ExpiresAt) +} + +// UserInvite contains the result of creating or regenerating an invite +type UserInvite struct { + UserInfo *UserInfo + InviteToken string + InviteExpiresAt time.Time + InviteCreatedAt time.Time +} + +// UserInviteInfo contains public information about an invite (for unauthenticated endpoint) +type UserInviteInfo struct { + Email string `json:"email"` + Name string `json:"name"` + ExpiresAt time.Time `json:"expires_at"` + Valid bool `json:"valid"` + InvitedBy string `json:"invited_by"` +} 
+ +// NewInviteID generates a new invite ID using xid +func NewInviteID() string { + return xid.New().String() +} + +// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place. +func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if i.Email != "" { + i.Email, err = enc.Encrypt(i.Email) + if err != nil { + return fmt.Errorf("encrypt email: %w", err) + } + } + + if i.Name != "" { + i.Name, err = enc.Encrypt(i.Name) + if err != nil { + return fmt.Errorf("encrypt name: %w", err) + } + } + + return nil +} + +// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place. +func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if i.Email != "" { + i.Email, err = enc.Decrypt(i.Email) + if err != nil { + return fmt.Errorf("decrypt email: %w", err) + } + } + + if i.Name != "" { + i.Name, err = enc.Decrypt(i.Name) + if err != nil { + return fmt.Errorf("decrypt name: %w", err) + } + } + + return nil +} + +// Copy creates a deep copy of the UserInviteRecord +func (i *UserInviteRecord) Copy() *UserInviteRecord { + autoGroups := make([]string, len(i.AutoGroups)) + copy(autoGroups, i.AutoGroups) + + return &UserInviteRecord{ + ID: i.ID, + AccountID: i.AccountID, + Email: i.Email, + Name: i.Name, + Role: i.Role, + AutoGroups: autoGroups, + HashedToken: i.HashedToken, + ExpiresAt: i.ExpiresAt, + CreatedAt: i.CreatedAt, + CreatedBy: i.CreatedBy, + } +} diff --git a/management/server/types/user_invite_test.go b/management/server/types/user_invite_test.go new file mode 100644 index 000000000..09dae3800 --- /dev/null +++ b/management/server/types/user_invite_test.go @@ -0,0 +1,355 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/crc32" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + 
// TestUserInviteRecord_TableName pins the GORM table name.
func TestUserInviteRecord_TableName(t *testing.T) {
	invite := UserInviteRecord{}
	assert.Equal(t, "user_invites", invite.TableName())
}

// TestGenerateInviteToken_Success checks that generation yields both outputs.
func TestGenerateInviteToken_Success(t *testing.T) {
	hashedToken, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)
	assert.NotEmpty(t, hashedToken)
	assert.NotEmpty(t, plainToken)
}

// TestGenerateInviteToken_Length checks the fixed 40-character token length.
func TestGenerateInviteToken_Length(t *testing.T) {
	_, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)
	assert.Len(t, plainToken, InviteTokenLength)
}

// TestGenerateInviteToken_Prefix checks the "nbi_" prefix.
func TestGenerateInviteToken_Prefix(t *testing.T) {
	_, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)
	assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
}

// TestGenerateInviteToken_Hashing verifies the returned hash is the
// base64-encoded SHA-256 of the plain token.
func TestGenerateInviteToken_Hashing(t *testing.T) {
	hashedToken, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)

	expectedHash := sha256.Sum256([]byte(plainToken))
	expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
	assert.Equal(t, expectedHashedToken, hashedToken)
}

// TestGenerateInviteToken_Checksum recomputes the CRC32 over the secret part
// and compares it against the base62-decoded checksum suffix.
func TestGenerateInviteToken_Checksum(t *testing.T) {
	_, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)

	// Extract parts
	secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
	checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]

	// Verify checksum
	expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
	actualChecksum, err := base62.Decode(checksumStr)
	require.NoError(t, err)
	assert.Equal(t, expectedChecksum, actualChecksum)
}

// TestGenerateInviteToken_Uniqueness generates 100 tokens and checks for
// collisions.
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
	tokens := make(map[string]bool)
	for i := 0; i < 100; i++ {
		_, plainToken, err := GenerateInviteToken()
		require.NoError(t, err)
		assert.False(t, tokens[plainToken], "Token should be unique")
		tokens[plainToken] = true
	}
}

// TestHashInviteToken pins the hash to base64(SHA-256(token)).
func TestHashInviteToken(t *testing.T) {
	token := "nbi_testtoken123456789012345678901234"
	hashedToken := HashInviteToken(token)

	expectedHash := sha256.Sum256([]byte(token))
	expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
	assert.Equal(t, expectedHashedToken, hashedToken)
}

// TestHashInviteToken_Consistency checks the hash is deterministic.
func TestHashInviteToken_Consistency(t *testing.T) {
	token := "nbi_testtoken123456789012345678901234"
	hash1 := HashInviteToken(token)
	hash2 := HashInviteToken(token)
	assert.Equal(t, hash1, hash2)
}

// TestHashInviteToken_DifferentTokens checks distinct tokens hash differently.
func TestHashInviteToken_DifferentTokens(t *testing.T) {
	token1 := "nbi_testtoken123456789012345678901234"
	token2 := "nbi_testtoken123456789012345678901235"
	hash1 := HashInviteToken(token1)
	hash2 := HashInviteToken(token2)
	assert.NotEqual(t, hash1, hash2)
}

// TestValidateInviteToken_Success validates a freshly generated token.
func TestValidateInviteToken_Success(t *testing.T) {
	_, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)

	err = ValidateInviteToken(plainToken)
	assert.NoError(t, err)
}

// TestValidateInviteToken_InvalidLength rejects tokens of the wrong length.
func TestValidateInviteToken_InvalidLength(t *testing.T) {
	testCases := []struct {
		name  string
		token string
	}{
		{"empty", ""},
		{"too short", "nbi_abc"},
		{"too long", "nbi_" + strings.Repeat("a", 50)},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateInviteToken(tc.token)
			require.Error(t, err)
			assert.Contains(t, err.Error(), "invalid token length")
		})
	}
}

// TestValidateInviteToken_InvalidPrefix rejects a correctly sized token with
// the wrong prefix.
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
	// Create a token with wrong prefix but correct length
	token := "xyz_" + strings.Repeat("a", 30) + "000000"
	err := ValidateInviteToken(token)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid token prefix")
}

// TestValidateInviteToken_InvalidChecksum rejects a well-formed token whose
// checksum does not match the secret.
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
	// Create a token with correct format but invalid checksum
	token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
	err := ValidateInviteToken(token)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "checksum")
}

// TestValidateInviteToken_ModifiedToken flips one secret character and
// expects validation to fail (checksum mismatch).
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
	_, plainToken, err := GenerateInviteToken()
	require.NoError(t, err)

	// Modify one character in the secret part
	modifiedToken := plainToken[:5] + "X" + plainToken[6:]
	err = ValidateInviteToken(modifiedToken)
	require.Error(t, err)
}

// TestUserInviteRecord_IsExpired covers future, past, and just-past expiry.
func TestUserInviteRecord_IsExpired(t *testing.T) {
	t.Run("not expired", func(t *testing.T) {
		invite := &UserInviteRecord{
			ExpiresAt: time.Now().Add(time.Hour),
		}
		assert.False(t, invite.IsExpired())
	})

	t.Run("expired", func(t *testing.T) {
		invite := &UserInviteRecord{
			ExpiresAt: time.Now().Add(-time.Hour),
		}
		assert.True(t, invite.IsExpired())
	})

	t.Run("just expired", func(t *testing.T) {
		invite := &UserInviteRecord{
			ExpiresAt: time.Now().Add(-time.Second),
		}
		assert.True(t, invite.IsExpired())
	})
}

// TestNewInviteID checks the xid-based ID shape.
func TestNewInviteID(t *testing.T) {
	id := NewInviteID()
	assert.NotEmpty(t, id)
	assert.Len(t, id, 20) // xid generates 20 character IDs
}

// TestNewInviteID_Uniqueness generates 100 IDs and checks for collisions.
func TestNewInviteID_Uniqueness(t *testing.T) {
	ids := make(map[string]bool)
	for i := 0; i < 100; i++ {
		id := NewInviteID()
		assert.False(t, ids[id], "ID should be unique")
		ids[id] = true
	}
}

// TestUserInviteRecord_EncryptDecryptSensitiveData covers the round trip,
// empty fields, and the nil-encryptor no-op contract.
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
	key, err := crypt.GenerateKey()
	require.NoError(t, err)
	fieldEncrypt, err := crypt.NewFieldEncrypt(key)
	require.NoError(t, err)

	t.Run("encrypt and decrypt", func(t *testing.T) {
		invite := &UserInviteRecord{
			ID:        "test-invite",
			AccountID: "test-account",
			Email:     "test@example.com",
			Name:      "Test User",
			Role:      "user",
		}

		// Encrypt
		err := invite.EncryptSensitiveData(fieldEncrypt)
		require.NoError(t, err)

		// Verify encrypted values are different from original
		assert.NotEqual(t, "test@example.com", invite.Email)
		assert.NotEqual(t, "Test User", invite.Name)

		// Decrypt
		err = invite.DecryptSensitiveData(fieldEncrypt)
		require.NoError(t, err)

		// Verify decrypted values match original
		assert.Equal(t, "test@example.com", invite.Email)
		assert.Equal(t, "Test User", invite.Name)
	})

	t.Run("encrypt empty fields", func(t *testing.T) {
		invite := &UserInviteRecord{
			ID:        "test-invite",
			AccountID: "test-account",
			Email:     "",
			Name:      "",
			Role:      "user",
		}

		err := invite.EncryptSensitiveData(fieldEncrypt)
		require.NoError(t, err)
		assert.Equal(t, "", invite.Email)
		assert.Equal(t, "", invite.Name)

		err = invite.DecryptSensitiveData(fieldEncrypt)
		require.NoError(t, err)
		assert.Equal(t, "", invite.Email)
		assert.Equal(t, "", invite.Name)
	})

	t.Run("nil encryptor", func(t *testing.T) {
		invite := &UserInviteRecord{
			ID:        "test-invite",
			AccountID: "test-account",
			Email:     "test@example.com",
			Name:      "Test User",
			Role:      "user",
		}

		err := invite.EncryptSensitiveData(nil)
		require.NoError(t, err)
		assert.Equal(t, "test@example.com", invite.Email)
		assert.Equal(t, "Test User", invite.Name)

		err = invite.DecryptSensitiveData(nil)
		require.NoError(t, err)
		assert.Equal(t, "test@example.com", invite.Email)
		assert.Equal(t, "Test User", invite.Name)
	})
}
copied.HashedToken) + assert.Equal(t, original.ExpiresAt, copied.ExpiresAt) + assert.Equal(t, original.CreatedAt, copied.CreatedAt) + assert.Equal(t, original.CreatedBy, copied.CreatedBy) + + // Verify deep copy of AutoGroups (modifying copy doesn't affect original) + copied.AutoGroups[0] = "modified" + assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0]) + assert.Equal(t, "group1", original.AutoGroups[0]) +} + +func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) { + original := &UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + AutoGroups: []string{}, + } + + copied := original.Copy() + assert.NotNil(t, copied.AutoGroups) + assert.Len(t, copied.AutoGroups, 0) +} + +func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) { + original := &UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + AutoGroups: nil, + } + + copied := original.Copy() + assert.NotNil(t, copied.AutoGroups) + assert.Len(t, copied.AutoGroups, 0) +} + +func TestInviteTokenConstants(t *testing.T) { + // Verify constants are consistent + expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength + assert.Equal(t, InviteTokenLength, expectedLength) + assert.Equal(t, 4, len(InviteTokenPrefix)) + assert.Equal(t, 30, InviteTokenSecretLength) + assert.Equal(t, 6, InviteTokenChecksumLength) + assert.Equal(t, 40, InviteTokenLength) + assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours + assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour +} + +func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) { + // Generate multiple tokens and ensure they all validate + for i := 0; i < 50; i++ { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + err = ValidateInviteToken(plainToken) + assert.NoError(t, err, "Generated token should always be valid") + } +} + +func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) { + hashedToken, plainToken, err := GenerateInviteToken() 
+ require.NoError(t, err) + + // HashInviteToken should produce the same hash as GenerateInviteToken + rehashedToken := HashInviteToken(plainToken) + assert.Equal(t, hashedToken, rehashedToken) +} diff --git a/management/server/user.go b/management/server/user.go index 1f38b749f..51da7a633 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "time" + "unicode" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth" @@ -704,7 +705,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta) }) } @@ -718,7 +719,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) }) } @@ -1282,7 +1283,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI addPeerRemovedEvent() } - meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt} + meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt, "issued": targetUser.Issued} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, accountID, activity.UserDeleted, meta) return updateAccountPeers, nil @@ -1453,3 +1454,368 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init 
return nil } + +// CreateUserInvite creates an invite link for a new user in the embedded IdP. +// The user is NOT created until the invite is accepted. +func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + if err := validateUserInvite(invite); err != nil { + return nil, err + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + // Check if user already exists in NetBird DB + existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, err + } + for _, user := range existingUsers { + if strings.EqualFold(user.Email, invite.Email) { + return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists") + } + } + + // Check if invite already exists for this email + existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email) + if err != nil { + if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { + return nil, fmt.Errorf("failed to check existing invites: %w", err) + } + } + if existingInvite != nil { + return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email") + } + + // Calculate expiration time + if expiresIn <= 0 { + expiresIn = types.DefaultInviteExpirationSeconds + } + + if expiresIn < types.MinInviteExpirationSeconds { + return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour") + } + expiresAt := 
time.Now().UTC().Add(time.Duration(expiresIn) * time.Second) + + // Generate invite token + inviteID := types.NewInviteID() + hashedToken, plainToken, err := types.GenerateInviteToken() + if err != nil { + return nil, fmt.Errorf("failed to generate invite token: %w", err) + } + + // Create the invite record (no user created yet) + userInvite := &types.UserInviteRecord{ + ID: inviteID, + AccountID: accountID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + HashedToken: hashedToken, + ExpiresAt: expiresAt, + CreatedAt: time.Now().UTC(), + CreatedBy: initiatorUserID, + } + + if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email}) + + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + Status: string(types.UserStatusInvited), + Issued: types.UserIssuedAPI, + }, + InviteToken: plainToken, + InviteExpiresAt: expiresAt, + }, nil +} + +// GetUserInviteInfo retrieves invite information from a token (public endpoint). 
+func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) { + if err := types.ValidateInviteToken(token); err != nil { + return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err) + } + + hashedToken := types.HashInviteToken(token) + invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken) + if err != nil { + return nil, err + } + + // Get the inviter's name + invitedBy := "" + if invite.CreatedBy != "" { + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy) + if err == nil && inviter != nil { + invitedBy = inviter.Name + } + } + + return &types.UserInviteInfo{ + Email: invite.Email, + Name: invite.Name, + ExpiresAt: invite.ExpiresAt, + Valid: !invite.IsExpired(), + InvitedBy: invitedBy, + }, nil +} + +// ListUserInvites returns all invites for an account. +func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, err + } + + invites := make([]*types.UserInvite, 0, len(records)) + for _, record := range records { + invites = append(invites, &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: record.ID, + Email: record.Email, + Name: record.Name, + Role: record.Role, + AutoGroups: record.AutoGroups, + }, + InviteExpiresAt: record.ExpiresAt, + InviteCreatedAt: 
record.CreatedAt, + }) + } + + return invites, nil +} + +// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB. +func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error { + if !IsEmbeddedIdp(am.idpManager) { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + if password == "" { + return status.Errorf(status.InvalidArgument, "password is required") + } + + if err := validatePassword(password); err != nil { + return status.Errorf(status.InvalidArgument, "invalid password: %v", err) + } + + if err := types.ValidateInviteToken(token); err != nil { + return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err) + } + + hashedToken := types.HashInviteToken(token) + invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken) + if err != nil { + return err + } + + if invite.IsExpired() { + return status.Errorf(status.InvalidArgument, "invite has expired") + } + + // Create user in Dex with the provided password + embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return status.Errorf(status.Internal, "failed to get embedded IdP manager") + } + + idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name) + if err != nil { + return fmt.Errorf("failed to create user in IdP: %w", err) + } + + // Create user in NetBird DB + newUser := &types.User{ + Id: idpUser.ID, + AccountID: invite.AccountID, + Role: types.StrRoleToUserRole(invite.Role), + AutoGroups: invite.AutoGroups, + Issued: types.UserIssuedAPI, + CreatedAt: time.Now().UTC(), + Email: invite.Email, + Name: invite.Name, + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := transaction.SaveUser(ctx, newUser); err != nil { + return fmt.Errorf("failed to save user: %w", err) + } + if err := transaction.DeleteUserInvite(ctx, 
invite.ID); err != nil { + return fmt.Errorf("failed to delete invite: %w", err) + } + return nil + }) + if err != nil { + // Best-effort rollback: delete the IdP user to avoid orphaned records + if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil { + log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID) + } + return err + } + + am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email}) + + return nil +} + +// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one. +func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + // Get existing invite + existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) + if err != nil { + return nil, err + } + + // Calculate expiration time + if expiresIn <= 0 { + expiresIn = types.DefaultInviteExpirationSeconds + } + if expiresIn < types.MinInviteExpirationSeconds { + return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour") + } + expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second) + + // Generate new invite token + hashedToken, plainToken, err := types.GenerateInviteToken() + if err != nil { + return nil, fmt.Errorf("failed to generate invite token: %w", err) + } + + // Update 
existing invite with new token and expiration + existingInvite.HashedToken = hashedToken + existingInvite.ExpiresAt = expiresAt + existingInvite.CreatedBy = initiatorUserID + + err = am.Store.SaveUserInvite(ctx, existingInvite) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email}) + + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: existingInvite.ID, + Email: existingInvite.Email, + Name: existingInvite.Name, + Role: existingInvite.Role, + AutoGroups: existingInvite.AutoGroups, + Status: string(types.UserStatusInvited), + Issued: types.UserIssuedAPI, + }, + InviteToken: plainToken, + InviteExpiresAt: expiresAt, + }, nil +} + +// DeleteUserInvite deletes an existing invite by ID. +func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + if !IsEmbeddedIdp(am.idpManager) { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) + if err != nil { + return err + } + + if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil { + return err + } + + am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email}) + + return nil +} + +const minPasswordLength = 8 + +// validatePassword checks password strength requirements: +// - Minimum 8 characters +// - At least 1 digit +// - At least 1 uppercase letter +// - At least 1 special character +func 
validatePassword(password string) error { + if len(password) < minPasswordLength { + return errors.New("password must be at least 8 characters long") + } + + var hasDigit, hasUpper, hasSpecial bool + for _, c := range password { + switch { + case unicode.IsDigit(c): + hasDigit = true + case unicode.IsUpper(c): + hasUpper = true + case !unicode.IsLetter(c) && !unicode.IsDigit(c): + hasSpecial = true + } + } + + var missing []string + if !hasDigit { + missing = append(missing, "one digit") + } + if !hasUpper { + missing = append(missing, "one uppercase letter") + } + if !hasSpecial { + missing = append(missing, "one special character") + } + + if len(missing) > 0 { + return errors.New("password must contain at least " + strings.Join(missing, ", ")) + } + + return nil +} diff --git a/management/server/user_invite_test.go b/management/server/user_invite_test.go new file mode 100644 index 000000000..6256ed44a --- /dev/null +++ b/management/server/user_invite_test.go @@ -0,0 +1,1010 @@ +package server + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util/crypt" +) + +const ( + testAccountID = "testAccountID" + testAdminUserID = "testAdminUserID" + testRegularUserID = "testRegularUserID" +) + +// setupInviteTestManagerWithEmbeddedIdP creates a test manager with a real embedded IdP +// and store encryption enabled. This is required for tests that need to pass the IsEmbeddedIdp check. 
+func setupInviteTestManagerWithEmbeddedIdP(t *testing.T) (*DefaultAccountManager, func()) { + t.Helper() + ctx := context.Background() + + tmpDir := t.TempDir() + dexDataDir := tmpDir + "/dex" + require.NoError(t, os.MkdirAll(dexDataDir, 0700)) + + // Create test store + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", tmpDir) + require.NoError(t, err, "Error when creating store") + + // Enable encryption + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + s.SetFieldEncrypt(fieldEncrypt) + + // Create embedded IDP config + embeddedIdPConfig := &idp.EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: idp.EmbeddedStorageConfig{ + Type: "sqlite3", + Config: idp.EmbeddedStorageTypeConfig{ + File: dexDataDir + "/dex.db", + }, + }, + } + + // Create embedded IDP manager + embeddedIdp, err := idp.NewEmbeddedIdPManager(ctx, embeddedIdPConfig, nil) + require.NoError(t, err) + + account := newAccountWithId(ctx, testAccountID, testAdminUserID, "", "admin@test.com", "Admin User", false) + account.Users[testRegularUserID] = &types.User{ + Id: testRegularUserID, + AccountID: testAccountID, + Role: types.UserRoleUser, + Email: "regular@test.com", + Name: "Regular User", + } + + err = s.SaveAccount(ctx, account) + require.NoError(t, err, "Error when saving account") + + permissionsManager := permissions.NewManager(s) + + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + idpManager: embeddedIdp, + } + + cleanupFunc := func() { + _ = embeddedIdp.Stop(ctx) + cleanup() + } + + return &am, cleanupFunc +} + +func TestCreateUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := 
am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "newuser@test.com", result.UserInfo.Email) + assert.Equal(t, "New User", result.UserInfo.Name) + assert.Equal(t, "user", result.UserInfo.Role) + assert.Equal(t, string(types.UserStatusInvited), result.UserInfo.Status) + assert.NotEmpty(t, result.InviteToken) + assert.True(t, result.InviteExpiresAt.After(time.Now())) + + // Verify invite is stored in DB + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 1) + assert.Equal(t, "newuser@test.com", invites[0].Email) +} + +func TestCreateUserInvite_DuplicateEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Create first invite + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Try to create duplicate invite + _, err = am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.AlreadyExists, sErr.Type()) +} + +func TestCreateUserInvite_ExistingUserEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Try to invite with an email that already exists as a user + invite := &types.UserInfo{ + Email: "regular@test.com", // Already exists as a user + Name: "Duplicate User", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, 
status.UserAlreadyExists, sErr.Type()) +} + +func TestCreateUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Regular user should not be able to create invites + _, err := am.CreateUserInvite(context.Background(), testAccountID, testRegularUserID, invite, 0) + require.Error(t, err) +} + +func TestCreateUserInvite_InvalidEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestCreateUserInvite_InvalidName(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestCreateUserInvite_OwnerRole(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newowner@test.com", + Name: "New Owner", + Role: "owner", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func 
TestCreateUserInvite_ExpirationTooShort(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Try to create with expiration less than 1 hour + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 1800) // 30 minutes + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "at least 1 hour") +} + +func TestCreateUserInvite_CustomExpiration(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + expiresIn := 7200 // 2 hours + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, expiresIn) + require.NoError(t, err) + + // Verify expiration is approximately 2 hours from now + expectedExpiration := time.Now().Add(time.Duration(expiresIn) * time.Second) + assert.WithinDuration(t, expectedExpiration, result.InviteExpiresAt, time.Minute) +} + +func TestCreateUserInvite_WithAutoGroups(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{"group1", "group2"}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + assert.Equal(t, []string{"group1", "group2"}, result.UserInfo.AutoGroups) + + // Verify invite in DB has auto groups + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + require.Len(t, invites, 1) + assert.Equal(t, 
[]string{"group1", "group2"}, invites[0].AutoGroups) +} + +func TestGetUserInviteInfo_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Get the invite info using the token + info, err := am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.NoError(t, err) + require.NotNil(t, info) + + assert.Equal(t, "newuser@test.com", info.Email) + assert.Equal(t, "New User", info.Name) + assert.True(t, info.Valid) + assert.Equal(t, "Admin User", info.InvitedBy) +} + +func TestGetUserInviteInfo_InvalidToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.GetUserInviteInfo(context.Background(), "invalid_token") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestGetUserInviteInfo_TokenNotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Generate a valid token format that doesn't exist in DB + _, validToken, err := types.GenerateInviteToken() + require.NoError(t, err) + + _, err = am.GetUserInviteInfo(context.Background(), validToken) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestGetUserInviteInfo_ExpiredInvite(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with valid expiration + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := 
am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Manually set the invite to expired by updating the store directly + inviteRecord, err := am.Store.GetUserInviteByID(context.Background(), store.LockingStrengthUpdate, testAccountID, result.UserInfo.ID) + require.NoError(t, err) + inviteRecord.ExpiresAt = time.Now().Add(-time.Hour) // Set to 1 hour ago + err = am.Store.SaveUserInvite(context.Background(), inviteRecord) + require.NoError(t, err) + + // Get the invite info - should still return info but Valid should be false + info, err := am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.NoError(t, err) + assert.False(t, info.Valid) +} + +func TestListUserInvites_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create multiple invites + for i, email := range []string{"user1@test.com", "user2@test.com", "user3@test.com"} { + invite := &types.UserInfo{ + Email: email, + Name: "User " + string(rune('1'+i)), + Role: "user", + AutoGroups: []string{}, + } + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + } + + // List invites + invites, err := am.ListUserInvites(context.Background(), testAccountID, testAdminUserID) + require.NoError(t, err) + assert.Len(t, invites, 3) +} + +func TestListUserInvites_Empty(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invites, err := am.ListUserInvites(context.Background(), testAccountID, testAdminUserID) + require.NoError(t, err) + assert.Len(t, invites, 0) +} + +func TestListUserInvites_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.ListUserInvites(context.Background(), testAccountID, testRegularUserID) + require.Error(t, err) +} + +func TestRegenerateUserInvite_Success(t *testing.T) { + am, 
cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + originalResult, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Regenerate the invite + newResult, err := am.RegenerateUserInvite(context.Background(), testAccountID, testAdminUserID, originalResult.UserInfo.ID, 0) + require.NoError(t, err) + require.NotNil(t, newResult) + + // Verify invite ID remains the same (stable ID for clients) + assert.Equal(t, originalResult.UserInfo.ID, newResult.UserInfo.ID) + + // Verify new token is different + assert.NotEqual(t, originalResult.InviteToken, newResult.InviteToken) + assert.Equal(t, "newuser@test.com", newResult.UserInfo.Email) + assert.Equal(t, "New User", newResult.UserInfo.Name) + + // Verify old token no longer works + _, err = am.GetUserInviteInfo(context.Background(), originalResult.InviteToken) + require.Error(t, err) + + // Verify new token works + info, err := am.GetUserInviteInfo(context.Background(), newResult.InviteToken) + require.NoError(t, err) + assert.Equal(t, "newuser@test.com", info.Email) +} + +func TestRegenerateUserInvite_NotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.RegenerateUserInvite(context.Background(), testAccountID, testAdminUserID, "nonexistent-id", 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestRegenerateUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := 
am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Regular user should not be able to regenerate + _, err = am.RegenerateUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID, 0) + require.Error(t, err) +} + +func TestDeleteUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Delete the invite + err = am.DeleteUserInvite(context.Background(), testAccountID, testAdminUserID, result.UserInfo.ID) + require.NoError(t, err) + + // Verify invite is deleted + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 0) + + // Verify token no longer works + _, err = am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.Error(t, err) +} + +func TestDeleteUserInvite_NotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + err := am.DeleteUserInvite(context.Background(), testAccountID, testAdminUserID, "nonexistent-id") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestDeleteUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + 
+ // Regular user should not be able to delete + err = am.DeleteUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID) + require.Error(t, err) +} + +func TestDeleteUserInvite_WrongAccount(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Create another account + anotherAccountID := "anotherAccountID" + anotherAdminID := "anotherAdminID" + anotherAccount := newAccountWithId(context.Background(), anotherAccountID, anotherAdminID, "", "otheradmin@test.com", "Other Admin", false) + err = am.Store.SaveAccount(context.Background(), anotherAccount) + require.NoError(t, err) + + // Try to delete from wrong account + err = am.DeleteUserInvite(context.Background(), anotherAccountID, anotherAdminID, result.UserInfo.ID) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestAcceptUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Accept the invite with a valid password + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.NoError(t, err) + + // Verify user is created in DB + users, err := am.Store.GetAccountUsers(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + + var foundUser *types.User + for _, u := range users 
{ + if u.Email == "newuser@test.com" { + foundUser = u + break + } + } + require.NotNil(t, foundUser, "User should be created in DB") + assert.Equal(t, "New User", foundUser.Name) + assert.Equal(t, types.UserRoleUser, foundUser.Role) + + // Verify invite is deleted + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 0) +} + +func TestAcceptUserInvite_InvalidToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + err := am.AcceptUserInvite(context.Background(), "invalid_token", "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestAcceptUserInvite_TokenNotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Generate a valid token format that doesn't exist in DB + _, validToken, err := types.GenerateInviteToken() + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), validToken, "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestAcceptUserInvite_ExpiredToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with valid expiration + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Manually set the invite to expired by updating the store directly + inviteRecord, err := am.Store.GetUserInviteByID(context.Background(), store.LockingStrengthUpdate, testAccountID, result.UserInfo.ID) + require.NoError(t, err) + inviteRecord.ExpiresAt = 
time.Now().Add(-time.Hour) // Set to 1 hour ago + err = am.Store.SaveUserInvite(context.Background(), inviteRecord) + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "expired") +} + +func TestAcceptUserInvite_EmptyPassword(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "password is required") +} + +func TestAcceptUserInvite_WeakPassword(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + testCases := []struct { + name string + password string + expectedMsg string + }{ + {"too short", "Pass1!", "at least 8 characters"}, + {"no digit", "Password!", "one digit"}, + {"no uppercase", "password1!", "one uppercase"}, + {"no special", "Password1", "one special character"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := am.AcceptUserInvite(context.Background(), result.InviteToken, tc.password) + require.Error(t, err) + 
assert.Contains(t, err.Error(), tc.expectedMsg) + }) + } +} + +func TestValidatePassword(t *testing.T) { + testCases := []struct { + name string + password string + expectError bool + errorMsg string + }{ + {"valid password", "Password1!", false, ""}, + {"valid complex password", "MyP@ssw0rd#2024", false, ""}, + {"too short", "Pass1!", true, "at least 8 characters"}, + {"no digit", "Password!", true, "one digit"}, + {"no uppercase", "password1!", true, "one uppercase"}, + {"no special", "Password1", true, "one special character"}, + {"only lowercase", "password", true, "one digit"}, + {"no uppercase no special", "password1", true, "one uppercase"}, + {"all lowercase short", "pass", true, "at least 8 characters"}, + {"empty", "", true, "at least 8 characters"}, + {"spaces count as special", "Pass word1", false, ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validatePassword(tc.password) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestInviteToken_GenerateAndValidate(t *testing.T) { + hashedToken, plainToken, err := types.GenerateInviteToken() + require.NoError(t, err) + require.NotEmpty(t, hashedToken) + require.NotEmpty(t, plainToken) + + // Validate token format + assert.Len(t, plainToken, types.InviteTokenLength) + assert.True(t, len(plainToken) > len(types.InviteTokenPrefix)) + assert.Equal(t, types.InviteTokenPrefix, plainToken[:len(types.InviteTokenPrefix)]) + + // Validate checksum + err = types.ValidateInviteToken(plainToken) + require.NoError(t, err) + + // Verify hashing is consistent + hashedAgain := types.HashInviteToken(plainToken) + assert.Equal(t, hashedToken, hashedAgain) +} + +func TestInviteToken_ValidateInvalid(t *testing.T) { + testCases := []struct { + name string + token string + }{ + {"empty", ""}, + {"too short", "nbi_abc"}, + {"wrong prefix", "xyz_123456789012345678901234567890"}, + 
{"invalid checksum", "nbi_123456789012345678901234567890abcdef"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := types.ValidateInviteToken(tc.token) + require.Error(t, err) + }) + } +} + +func TestUserInviteRecord_IsExpired(t *testing.T) { + // Not expired + invite := &types.UserInviteRecord{ + ExpiresAt: time.Now().Add(time.Hour), + } + assert.False(t, invite.IsExpired()) + + // Expired + invite = &types.UserInviteRecord{ + ExpiresAt: time.Now().Add(-time.Hour), + } + assert.True(t, invite.IsExpired()) +} + +func TestUserInviteRecord_Copy(t *testing.T) { + original := &types.UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group1", "group2"}, + HashedToken: "hashed-token", + ExpiresAt: time.Now().Add(time.Hour), + CreatedAt: time.Now(), + CreatedBy: "creator-id", + } + + copied := original.Copy() + + assert.Equal(t, original.ID, copied.ID) + assert.Equal(t, original.AccountID, copied.AccountID) + assert.Equal(t, original.Email, copied.Email) + assert.Equal(t, original.Name, copied.Name) + assert.Equal(t, original.Role, copied.Role) + assert.Equal(t, original.AutoGroups, copied.AutoGroups) + assert.Equal(t, original.HashedToken, copied.HashedToken) + assert.Equal(t, original.ExpiresAt, copied.ExpiresAt) + assert.Equal(t, original.CreatedAt, copied.CreatedAt) + assert.Equal(t, original.CreatedBy, copied.CreatedBy) + + // Verify deep copy of AutoGroups + copied.AutoGroups[0] = "modified" + assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0]) +} + +func TestCreateUserInvite_NonEmbeddedIdP(t *testing.T) { + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + account := newAccountWithId(context.Background(), testAccountID, testAdminUserID, "", "admin@test.com", "Admin User", false) + err = s.SaveAccount(context.Background(), account) + 
require.NoError(t, err) + + permissionsManager := permissions.NewManager(s) + + // Use nil IDP manager (non-embedded) + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + idpManager: nil, + } + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + _, err = am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.PreconditionFailed, sErr.Type()) + assert.Contains(t, err.Error(), "embedded identity provider") +} + +func TestAcceptUserInvite_WithAutoGroups(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with auto groups + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "admin", + AutoGroups: []string{"group1", "group2"}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Accept the invite + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.NoError(t, err) + + // Verify user has the auto groups and role + users, err := am.Store.GetAccountUsers(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + + var foundUser *types.User + for _, u := range users { + if u.Email == "newuser@test.com" { + foundUser = u + break + } + } + require.NotNil(t, foundUser) + assert.Equal(t, types.UserRoleAdmin, foundUser.Role) + assert.Equal(t, []string{"group1", "group2"}, foundUser.AutoGroups) +} + +func TestUserInvite_EncryptDecryptSensitiveData(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + t.Run("encrypt and decrypt", func(t 
*testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + // Encrypt + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify encrypted values are different from original + assert.NotEqual(t, "test@example.com", invite.Email) + assert.NotEqual(t, "Test User", invite.Name) + + // Decrypt + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify decrypted values match original + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) + + t.Run("encrypt empty fields", func(t *testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "", + Name: "", + Role: "user", + } + + // Encrypt empty fields + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Empty strings should remain empty + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + + // Decrypt empty fields + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Should still be empty + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + }) + + t.Run("nil encryptor", func(t *testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + // Encrypt with nil encryptor should be no-op + err := invite.EncryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + + // Decrypt with nil encryptor should be no-op + err = invite.DecryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) +} diff --git a/shared/auth/jwt/validator.go b/shared/auth/jwt/validator.go index 
ede7acea5..aeaa5842c 100644 --- a/shared/auth/jwt/validator.go +++ b/shared/auth/jwt/validator.go @@ -72,8 +72,8 @@ var ( func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator { keys, err := getPemKeys(keysLocation) - if err != nil { - log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err) + if err != nil && !strings.Contains(keysLocation, "localhost") { + log.WithField("keysLocation", keysLocation).Warnf("could not get keys from location: %s, it will try again on the next http request", err) } return &Validator{ diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index f1ff98b16..26d2387d1 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -488,6 +488,171 @@ components: - role - auto_groups - is_service_user + UserInviteCreateRequest: + type: object + description: Request to create a user invite link + properties: + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + role: + description: User's NetBird account role + type: string + example: user + auto_groups: + description: Group IDs to auto-assign to peers registered by this user + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + expires_in: + description: Invite expiration time in seconds (default 72 hours) + type: integer + example: 259200 + required: + - email + - name + - role + - auto_groups + UserInvite: + type: object + description: A user invite + properties: + id: + description: Invite ID + type: string + example: d5p7eedra0h0lt6f59hg + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + role: + description: User's NetBird account role + type: string + example: user + 
auto_groups: + description: Group IDs to auto-assign to peers registered by this user + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + expires_at: + description: Invite expiration time + type: string + format: date-time + example: "2024-01-25T10:00:00Z" + created_at: + description: Invite creation time + type: string + format: date-time + example: "2024-01-22T10:00:00Z" + expired: + description: Whether the invite has expired + type: boolean + example: false + invite_token: + description: The invite link to be shared with the user. Only returned when the invite is created or regenerated. + type: string + example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123 + required: + - id + - email + - name + - role + - auto_groups + - expires_at + - created_at + - expired + UserInviteInfo: + type: object + description: Public information about an invite + properties: + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + expires_at: + description: Invite expiration time + type: string + format: date-time + example: "2024-01-25T10:00:00Z" + valid: + description: Whether the invite is still valid (not expired) + type: boolean + example: true + invited_by: + description: Name of the user who sent the invite + type: string + example: Admin User + required: + - email + - name + - expires_at + - valid + - invited_by + UserInviteAcceptRequest: + type: object + description: Request to accept an invite and set password + properties: + password: + description: >- + The password the user wants to set. Must be at least 8 characters long + and contain at least one uppercase letter, one digit, and one special + character (any character that is not a letter or digit, including spaces). + type: string + format: password + minLength: 8 + pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$' + example: SecurePass123! 
+ required: + - password + UserInviteAcceptResponse: + type: object + description: Response after accepting an invite + properties: + success: + description: Whether the invite was accepted successfully + type: boolean + example: true + required: + - success + UserInviteRegenerateRequest: + type: object + description: Request to regenerate an invite link + properties: + expires_in: + description: Invite expiration time in seconds (default 72 hours) + type: integer + example: 259200 + UserInviteRegenerateResponse: + type: object + description: Response after regenerating an invite + properties: + invite_token: + description: The new invite token + type: string + example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123 + invite_expires_at: + description: New invite expiration time + type: string + format: date-time + example: "2024-01-28T10:00:00Z" + required: + - invite_token + - invite_expires_at PeerMinimum: type: object properties: @@ -2071,7 +2236,8 @@ components: "dns.zone.create", "dns.zone.update", "dns.zone.delete", "dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete", "peer.job.create", - "user.password.change" + "user.password.change", + "user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete" ] example: route.add initiator_id: @@ -2642,6 +2808,29 @@ components: required: - user_id - email + InstanceVersionInfo: + type: object + description: Version information for NetBird components + properties: + management_current_version: + description: The current running version of the management server + type: string + example: "0.35.0" + dashboard_available_version: + description: The latest available version of the dashboard (from GitHub releases) + type: string + example: "2.10.0" + management_available_version: + description: The latest available version of the management server (from GitHub releases) + type: string + example: "0.35.0" + management_update_available: + description: Indicates 
if a newer management version is available + type: boolean + example: true + required: + - management_current_version + - management_update_available responses: not_found: description: Resource not found @@ -2694,6 +2883,27 @@ paths: $ref: '#/components/schemas/InstanceStatus' '500': "$ref": "#/components/responses/internal_error" + /api/instance/version: + get: + summary: Get Version Info + description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub. + tags: [ Instance ] + security: + - BearerAuth: [] + - TokenAuth: [] + responses: + '200': + description: Version information + content: + application/json: + schema: + $ref: '#/components/schemas/InstanceVersionInfo' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/setup: post: summary: Setup Instance @@ -3312,6 +3522,210 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/invites: + get: + summary: List user invites + description: Lists all pending invites for the account. Only available when embedded IdP is enabled. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: List of invites + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UserInvite' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a user invite + description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite. 
+ tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: User invite information + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteCreateRequest' + responses: + '200': + description: Invite created successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInvite' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '409': + description: User or invite already exists + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{inviteId}: + delete: + summary: Delete a user invite + description: Deletes a pending invite. Only available when embedded IdP is enabled. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: inviteId + required: true + schema: + type: string + description: The ID of the invite to delete + responses: + '200': + description: Invite deleted successfully + content: { } + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + description: Invite not found + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{inviteId}/regenerate: + post: + summary: Regenerate a user invite + description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one. 
+ tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: inviteId + required: true + schema: + type: string + description: The ID of the invite to regenerate + requestBody: + description: Regenerate options + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteRegenerateRequest' + responses: + '200': + description: Invite regenerated successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteRegenerateResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + description: Invite not found + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{token}: + get: + summary: Get invite information + description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself. + tags: [ Users ] + security: [] + parameters: + - in: path + name: token + required: true + schema: + type: string + description: The invite token + responses: + '200': + description: Invite information + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteInfo' + '400': + "$ref": "#/components/responses/bad_request" + '404': + description: Invite not found or invalid token + content: { } + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{token}/accept: + post: + summary: Accept an invite + description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself. 
+ tags: [ Users ] + security: [] + parameters: + - in: path + name: token + required: true + schema: + type: string + description: The invite token + requestBody: + description: Password to set for the new user + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteAcceptRequest' + responses: + '200': + description: Invite accepted successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteAcceptResponse' + '400': + "$ref": "#/components/responses/bad_request" + '404': + description: Invite not found or invalid token + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled or invite expired + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: List all Peers diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 848023689..e8c044b32 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -123,6 +123,10 @@ const ( EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" EventActivityCodeUserInvite EventActivityCode = "user.invite" + EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept" + EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create" + EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete" + EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate" EventActivityCodeUserJoin EventActivityCode = "user.join" EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change" EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" @@ -870,6 +874,21 @@ type InstanceStatus struct { SetupRequired bool `json:"setup_required"` } +// 
InstanceVersionInfo Version information for NetBird components +type InstanceVersionInfo struct { + // DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases) + DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"` + + // ManagementAvailableVersion The latest available version of the management server (from GitHub releases) + ManagementAvailableVersion *string `json:"management_available_version,omitempty"` + + // ManagementCurrentVersion The current running version of the management server + ManagementCurrentVersion string `json:"management_current_version"` + + // ManagementUpdateAvailable Indicates if a newer management version is available + ManagementUpdateAvailable bool `json:"management_update_available"` +} + // JobRequest defines model for JobRequest. type JobRequest struct { Workload WorkloadRequest `json:"workload"` @@ -2166,6 +2185,99 @@ type UserCreateRequest struct { Role string `json:"role"` } +// UserInvite A user invite +type UserInvite struct { + // AutoGroups Group IDs to auto-assign to peers registered by this user + AutoGroups []string `json:"auto_groups"` + + // CreatedAt Invite creation time + CreatedAt time.Time `json:"created_at"` + + // Email User's email address + Email string `json:"email"` + + // Expired Whether the invite has expired + Expired bool `json:"expired"` + + // ExpiresAt Invite expiration time + ExpiresAt time.Time `json:"expires_at"` + + // Id Invite ID + Id string `json:"id"` + + // InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated. + InviteToken *string `json:"invite_token,omitempty"` + + // Name User's full name + Name string `json:"name"` + + // Role User's NetBird account role + Role string `json:"role"` +} + +// UserInviteAcceptRequest Request to accept an invite and set password +type UserInviteAcceptRequest struct { + // Password The password the user wants to set. 
Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces). + Password string `json:"password"` +} + +// UserInviteAcceptResponse Response after accepting an invite +type UserInviteAcceptResponse struct { + // Success Whether the invite was accepted successfully + Success bool `json:"success"` +} + +// UserInviteCreateRequest Request to create a user invite link +type UserInviteCreateRequest struct { + // AutoGroups Group IDs to auto-assign to peers registered by this user + AutoGroups []string `json:"auto_groups"` + + // Email User's email address + Email string `json:"email"` + + // ExpiresIn Invite expiration time in seconds (default 72 hours) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name User's full name + Name string `json:"name"` + + // Role User's NetBird account role + Role string `json:"role"` +} + +// UserInviteInfo Public information about an invite +type UserInviteInfo struct { + // Email User's email address + Email string `json:"email"` + + // ExpiresAt Invite expiration time + ExpiresAt time.Time `json:"expires_at"` + + // InvitedBy Name of the user who sent the invite + InvitedBy string `json:"invited_by"` + + // Name User's full name + Name string `json:"name"` + + // Valid Whether the invite is still valid (not expired) + Valid bool `json:"valid"` +} + +// UserInviteRegenerateRequest Request to regenerate an invite link +type UserInviteRegenerateRequest struct { + // ExpiresIn Invite expiration time in seconds (default 72 hours) + ExpiresIn *int `json:"expires_in,omitempty"` +} + +// UserInviteRegenerateResponse Response after regenerating an invite +type UserInviteRegenerateResponse struct { + // InviteExpiresAt New invite expiration time + InviteExpiresAt time.Time `json:"invite_expires_at"` + + // InviteToken The new invite token + InviteToken string `json:"invite_token"` +} + // UserPermissions defines model for 
UserPermissions. type UserPermissions struct { // IsRestricted Indicates whether this User's Peers view is restricted @@ -2418,6 +2530,15 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest // PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType. type PostApiUsersJSONRequestBody = UserCreateRequest +// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType. +type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest + +// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType. +type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest + +// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType. +type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest + // PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType. type PutApiUsersUserIdJSONRequestBody = UserRequest diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index bc2d4d1be..523beb32b 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -154,9 +154,20 @@ func (s *SharedSocket) updateRouter() { } } -// LocalAddr returns an IPv4 address using the supplied port +// LocalAddr returns the local address, preferring IPv4 for backward compatibility. 
func (s *SharedSocket) LocalAddr() net.Addr { - // todo check impact on ipv6 discovery + if s.conn4 != nil { + return &net.UDPAddr{ + IP: net.IPv4zero, + Port: s.port, + } + } + if s.conn6 != nil { + return &net.UDPAddr{ + IP: net.IPv6zero, + Port: s.port, + } + } return &net.UDPAddr{ IP: net.IPv4zero, Port: s.port, diff --git a/util/crypt/crypt_test.go b/util/crypt/crypt_test.go new file mode 100644 index 000000000..143a4bbc2 --- /dev/null +++ b/util/crypt/crypt_test.go @@ -0,0 +1,139 @@ +package crypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateKey(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + assert.NotEmpty(t, key) + + _, err = NewFieldEncrypt(key) + assert.NoError(t, err) +} + +func TestNewFieldEncrypt_InvalidKey(t *testing.T) { + tests := []struct { + name string + key string + }{ + {name: "invalid base64", key: "not-valid-base64!!!"}, + {name: "too short", key: "c2hvcnQ="}, + {name: "empty", key: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewFieldEncrypt(tt.key) + assert.Error(t, err) + }) + } +} + +func TestEncryptDecrypt(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Email Address", input: "user@example.com"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted, err := ec.Encrypt(tc.input) + require.NoError(t, err) + + decrypted, 
err := ec.Decrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestEncrypt_DifferentCiphertexts(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + plaintext := "same plaintext" + + // Encrypt the same plaintext multiple times + encrypted1, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + encrypted2, err := ec.Encrypt(plaintext) + require.NoError(t, err) + + assert.NotEqual(t, encrypted1, encrypted2, "expected different ciphertexts for same plaintext (random nonce)") + + // Both should decrypt to the same plaintext + decrypted1, err := ec.Decrypt(encrypted1) + require.NoError(t, err) + + decrypted2, err := ec.Decrypt(encrypted2) + require.NoError(t, err) + + assert.Equal(t, plaintext, decrypted1) + assert.Equal(t, plaintext, decrypted2) +} + +func TestDecrypt_InvalidCiphertext(t *testing.T) { + key, err := GenerateKey() + assert.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + assert.NoError(t, err) + + tests := []struct { + name string + ciphertext string + }{ + {name: "invalid base64", ciphertext: "not-valid!!!"}, + {name: "too short", ciphertext: "c2hvcnQ="}, + {name: "corrupted", ciphertext: "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo="}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := ec.Decrypt(tt.ciphertext) + assert.Error(t, err) + assert.Empty(t, payload) + }) + } +} + +func TestDecrypt_WrongKey(t *testing.T) { + key1, _ := GenerateKey() + key2, _ := GenerateKey() + + ec1, _ := NewFieldEncrypt(key1) + ec2, _ := NewFieldEncrypt(key2) + + plaintext := "secret data" + encrypted, _ := ec1.Encrypt(plaintext) + + // Try to decrypt with wrong key + payload, err := ec2.Decrypt(encrypted) + assert.Error(t, err) + assert.Empty(t, payload) +} diff --git a/util/crypt/legacy.go b/util/crypt/legacy.go new file mode 100644 index 000000000..f84e6964f --- /dev/null +++ 
b/util/crypt/legacy.go @@ -0,0 +1,71 @@ +package crypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +// legacyIV is the static IV used by the legacy CBC encryption. +// Deprecated: This is kept only for backward compatibility with existing encrypted data. +var legacyIV = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} + +// LegacyEncrypt encrypts plaintext using AES-CBC with a static IV. +// Deprecated: Use Encrypt instead. This method is kept only for backward compatibility. +func (f *FieldEncrypt) LegacyEncrypt(plaintext string) string { + padded := pkcs5Padding([]byte(plaintext)) + ciphertext := make([]byte, len(padded)) + cbc := cipher.NewCBCEncrypter(f.block, legacyIV) + cbc.CryptBlocks(ciphertext, padded) + return base64.StdEncoding.EncodeToString(ciphertext) +} + +// LegacyDecrypt decrypts ciphertext that was encrypted using AES-CBC with a static IV. +// Deprecated: This method is kept only for backward compatibility with existing encrypted data. +func (f *FieldEncrypt) LegacyDecrypt(ciphertext string) (string, error) { + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode ciphertext: %w", err) + } + + cbc := cipher.NewCBCDecrypter(f.block, legacyIV) + cbc.CryptBlocks(data, data) + + plaintext, err := pkcs5UnPadding(data) + if err != nil { + return "", fmt.Errorf("unpad plaintext: %w", err) + } + + return string(plaintext), nil +} + +// pkcs5Padding adds PKCS#5 padding to the input. +func pkcs5Padding(data []byte) []byte { + padding := aes.BlockSize - len(data)%aes.BlockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +// pkcs5UnPadding removes PKCS#5 padding from the input. 
+func pkcs5UnPadding(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, fmt.Errorf("input data is empty") + } + + paddingLen := int(data[length-1]) + if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > length { + return nil, fmt.Errorf("invalid padding size") + } + + // Verify that all padding bytes are the same + for i := 0; i < paddingLen; i++ { + if data[length-1-i] != byte(paddingLen) { + return nil, fmt.Errorf("invalid padding") + } + } + + return data[:length-paddingLen], nil +} diff --git a/util/crypt/legacy_test.go b/util/crypt/legacy_test.go new file mode 100644 index 000000000..09b75a71f --- /dev/null +++ b/util/crypt/legacy_test.go @@ -0,0 +1,164 @@ +package crypt + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLegacyEncryptDecrypt(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + encrypted := ec.LegacyEncrypt(testData) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, testData, decrypted) +} + +func TestLegacyEncryptDecryptVariousInputs(t *testing.T) { + key, err := GenerateKey() + require.NoError(t, err) + + ec, err := NewFieldEncrypt(key) + require.NoError(t, err) + + testCases := []struct { + name string + input string + }{ + {name: "Empty String", input: ""}, + {name: "Short String", input: "Hello"}, + {name: "String with Spaces", input: "Hello, World!"}, + {name: "Long String", input: "The quick brown fox jumps over the lazy dog."}, + {name: "Unicode Characters", input: "こんにちは世界"}, + {name: "Special Characters", input: "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + {name: "Numeric String", input: "1234567890"}, + {name: "Repeated Characters", input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + {name: "Multi-block String", input: "This 
is a longer string that will span multiple blocks in the encryption algorithm."}, + {name: "Non-ASCII and ASCII Mix", input: "Hello 世界 123"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted := ec.LegacyEncrypt(tc.input) + assert.NotEmpty(t, encrypted) + + decrypted, err := ec.LegacyDecrypt(encrypted) + require.NoError(t, err) + + assert.Equal(t, tc.input, decrypted) + }) + } +} + +func TestPKCS5UnPadding(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + expectError bool + }{ + { + name: "Valid Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), + expected: []byte("Hello, World!"), + }, + { + name: "Empty Input", + input: []byte{}, + expectError: true, + }, + { + name: "Padding Length Zero", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Block Size", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Input Length", + input: []byte{5, 5, 5}, + expectError: true, + }, + { + name: "Invalid Padding Bytes", + input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), + expectError: true, + }, + { + name: "Valid Single Byte Padding", + input: append([]byte("Hello, World!"), byte(1)), + expected: []byte("Hello, World!"), + }, + { + name: "Invalid Mixed Padding Bytes", + input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), + expectError: true, + }, + { + name: "Valid Full Block Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("Hello, World!"), + }, + { + name: "Non-Padding Byte at End", + input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), + expectError: true, + }, + { + name: "Valid Padding with Different Text Length", + input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), + expected: []byte("Test"), + }, + { 
+ name: "Padding Length Equal to Input Length", + input: bytes.Repeat([]byte{8}, 8), + expected: []byte{}, + }, + { + name: "Invalid Padding Length Zero (Again)", + input: append([]byte("Test"), byte(0)), + expectError: true, + }, + { + name: "Padding Length Greater Than Input", + input: []byte{10}, + expectError: true, + }, + { + name: "Input Length Not Multiple of Block Size", + input: append([]byte("Invalid Length"), byte(1)), + expected: []byte("Invalid Length"), + }, + { + name: "Valid Padding with Non-ASCII Characters", + input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), + expected: []byte("こんにちは"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs5UnPadding(tt.input) + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +}