Compare commits

..

9 Commits

Author SHA1 Message Date
pascal
5c049f6f09 delete tests 2026-02-24 16:53:49 +01:00
pascal
740c726a78 extract proxy controller 2026-02-23 15:28:20 +01:00
pascal
3af287ebab move service manager 2026-02-20 01:21:05 +01:00
pascal
d4d885d434 Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy 2026-02-20 00:35:10 +01:00
pascal
d212332f5d store proxies in DB 2026-02-20 00:28:45 +01:00
pascal
0e11258e97 fix test 2026-02-19 13:30:26 +01:00
pascal
31ecf8f1f5 allow redis for token store 2026-02-19 13:28:36 +01:00
pascal
e2df1fb35e add accountID when sending update to cluster 2026-02-19 12:02:31 +01:00
pascal
942cd5dc72 export methods 2026-02-19 10:50:32 +01:00
127 changed files with 1161 additions and 14319 deletions

View File

@@ -1,14 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -1,194 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: "netbird expose --with-password safe-pass 8080",
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
})
if err != nil {
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
}
switch e := event.Event.(type) {
case *proto.ExposeServiceEvent_Ready:
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Port: %d\n", port)
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
return nil
default:
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -81,15 +80,6 @@ var (
Short: "",
Long: "",
SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
}
)
@@ -154,7 +144,6 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil
}
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -5,18 +5,20 @@ package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
_ = uapiSock.Close()
log.Errorf("failed to listen on uapi socket: %v", err)
return nil, err
}

View File

@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)

View File

@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {

View File

@@ -331,11 +331,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected)
if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan)
}
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()

View File

@@ -1,60 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -1,8 +0,0 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -1,121 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
return ruleIndex, nil
}

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
}
}
// Flow receiver domain is intentionally excluded from caching.
// Cloud providers may rotate the IP behind this domain; a stale cached record
// causes TLS certificate verification failures on reconnect.
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {

View File

@@ -391,8 +391,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
@@ -401,10 +400,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
@@ -413,5 +412,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.NotContains(t, domainStrings, "github.com")
assert.Contains(t, domainStrings, "github.com")
}

View File

@@ -351,13 +351,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
err := backoff.Retry(operation, exponentialBackOff)
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
log.Warn(err)
return
}

View File

@@ -36,7 +36,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -54,11 +53,13 @@ import (
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -74,6 +75,7 @@ import (
const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
@@ -206,6 +208,7 @@ type Engine struct {
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
@@ -221,8 +224,6 @@ type Engine struct {
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
}
// Peer is an instance of the Connection Peer
@@ -265,6 +266,7 @@ func NewEngine(
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
@@ -417,7 +419,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
if err != nil {
@@ -800,7 +801,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
disabled := autoUpdateSettings.Version == disableAutoUpdate
// stop and cleanup if disabled
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
@@ -1538,6 +1539,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1822,18 +1824,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
// GetFirewallManager returns the firewall manager.
// GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall
}
// GetExposeManager returns the expose session manager.
func (e *Engine) GetExposeManager() *expose.Manager {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return e.exposeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -1,95 +0,0 @@
package expose
import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
}
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Pin string
Password string
UserGroups []string
}
type ManagementClient interface {
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
RenewExpose(ctx context.Context, domain string) error
StopExpose(ctx context.Context, domain string) error
}
// Manager handles expose session lifecycle via the management client.
type Manager struct {
mgmClient ManagementClient
ctx context.Context
}
// NewManager creates a new expose Manager using the given management client.
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
return &Manager{mgmClient: mgmClient, ctx: ctx}
}
// Expose creates a new expose session via the management server.
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
log.Infof("exposing service on port %d", req.Port)
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
if err != nil {
return nil, err
}
log.Infof("expose session created for %s", resp.Domain)
return fromClientExposeResponse(resp), nil
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)
for {
select {
case <-ctx.Done():
log.Infof("context canceled, stopping keep alive for %s", domain)
return nil
case <-ticker.C:
if err := m.renew(ctx, domain); err != nil {
log.Errorf("renewing expose session for %s: %v", domain, err)
return err
}
}
}
}
// renew extends the TTL of an active expose session.
func (m *Manager) renew(ctx context.Context, domain string) error {
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
defer cancel()
return m.mgmClient.RenewExpose(renewCtx, domain)
}
// stop terminates an active expose session.
func (m *Manager) stop(domain string) {
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
defer cancel()
err := m.mgmClient.StopExpose(stopCtx, domain)
if err != nil {
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
}
}

View File

@@ -1,95 +0,0 @@
package expose
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
func TestManager_Expose_Success(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return &mgm.ExposeResponse{
ServiceName: "my-service",
ServiceURL: "https://my-service.example.com",
Domain: "my-service.example.com",
}, nil
},
}
m := NewManager(context.Background(), mock)
result, err := m.Expose(context.Background(), Request{Port: 8080})
require.NoError(t, err)
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
}
func TestManager_Expose_Error(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return nil, errors.New("permission denied")
},
}
m := NewManager(context.Background(), mock)
_, err := m.Expose(context.Background(), Request{Port: 8080})
require.Error(t, err)
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
}
func TestManager_Renew_Success(t *testing.T) {
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
return nil
},
}
m := NewManager(context.Background(), mock)
err := m.renew(context.Background(), "my-service.example.com")
require.NoError(t, err)
}
func TestManager_Renew_Timeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
return ctx.Err()
},
}
m := NewManager(ctx, mock)
err := m.renew(ctx, "my-service.example.com")
require.Error(t, err)
}
func TestNewRequest(t *testing.T) {
req := &daemonProto.ExposeServiceRequest{
Port: 8080,
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
Pin: "123456",
Password: "secret",
UserGroups: []string{"group1", "group2"},
Domain: "custom.example.com",
NamePrefix: "my-prefix",
}
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
}

View File

@@ -1,39 +0,0 @@
package expose
import (
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
}
}
func toClientExposeRequest(req Request) mgm.ExposeRequest {
return mgm.ExposeRequest{
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
}
}

View File

@@ -22,56 +22,51 @@ func prepareFd() (int, error) {
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for {
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
if err := waitReadable(ctx, fd); err != nil {
return err
}
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
continue
}
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("routing socket closed: %w", err)
}
return fmt.Errorf("read routing socket: %w", err)
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
if route.Dst.Bits() != 0 {
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
@@ -95,33 +90,3 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.
func waitReadable(ctx context.Context, fd int) error {
var fdset unix.FdSet
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
return fmt.Errorf("fd %d out of range for FdSet", fd)
}
for {
if err := ctx.Err(); err != nil {
return err
}
fdset = unix.FdSet{}
fdset.Set(fd)
// Use a 1-second timeout so we can re-check ctx periodically.
tv := unix.Timeval{Sec: 1}
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue
}
return fmt.Errorf("select on routing socket: %w", err)
}
if n > 0 {
return nil
}
// timeout — loop back and re-check ctx
}
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"runtime"
@@ -24,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
type ServiceDependencies struct {
@@ -32,6 +34,7 @@ type ServiceDependencies struct {
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
@@ -108,8 +111,9 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
wg sync.WaitGroup
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -135,6 +139,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
@@ -149,10 +154,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done()
return nil
}
@@ -163,6 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -196,6 +207,10 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
conn.opened = true
@@ -655,6 +670,19 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
}
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
maxWait := 300
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-ctx.Done():
case <-timeout.C:
}
}
func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn:

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
var testDispatcher = dispatcher.NewConnectionDispatcher()
@@ -52,6 +53,7 @@ func TestConn_GetKey(t *testing.T) {
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -69,6 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
@@ -107,6 +110,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0700); err != nil {
if err := os.MkdirAll(configDir, 0600); err != nil {
return "", err
}
}
@@ -206,15 +206,9 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil
}
func fileExists(path string) (bool, error) {
func fileExists(path string) bool {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
return !os.IsNotExist(err)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -641,11 +635,7 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
}
@@ -654,11 +644,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
@@ -671,7 +657,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err = util.EnforcePermission(input.ConfigPath)
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
@@ -798,12 +784,7 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
if fileExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -850,11 +831,7 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {

View File

@@ -256,11 +256,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
if fileExists(profPath) {
return ErrProfileAlreadyExists
}
@@ -289,11 +285,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
if !fileExists(profPath) {
return ErrProfileNotFound
}

View File

@@ -20,11 +20,7 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
}
stateFile := filepath.Join(configDir, profileName+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
if !fileExists(stateFile) {
return nil, errors.New("profile state file does not exist")
}

View File

@@ -263,14 +263,8 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer:
return
case routerStates := <-subscription.Events():
select {
case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
peerStateUpdate <- routerStates
log.Debugf("triggered route state update for Peer: %s", peerKey)
}
}
}

View File

@@ -1,80 +0,0 @@
package handler
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
)
type Agent interface {
Up(ctx context.Context) error
Down(ctx context.Context) error
Status() (internal.StatusType, error)
}
type SleepHandler struct {
agent Agent
mu sync.Mutex
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
sleepTriggeredDown bool
}
func New(agent Agent) *SleepHandler {
return &SleepHandler{
agent: agent,
}
}
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.sleepTriggeredDown {
log.Info("skipping up because wasn't sleep down")
return nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown = false
log.Info("running up after wake up")
err := s.agent.Up(ctx)
if err != nil {
log.Errorf("running up failed: %v", err)
return err
}
log.Info("running up command executed successfully")
return nil
}
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
status, err := s.agent.Status()
if err != nil {
return err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
return nil
}
log.Info("running down after system started sleeping")
if err = s.agent.Down(ctx); err != nil {
log.Errorf("running down failed: %v", err)
return err
}
s.sleepTriggeredDown = true
log.Info("running down executed successfully")
return nil
}

View File

@@ -1,153 +0,0 @@
package handler
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockAgent struct {
upErr error
downErr error
statusErr error
status internal.StatusType
upCalls int
}
func (m *mockAgent) Up(_ context.Context) error {
m.upCalls++
return m.upErr
}
func (m *mockAgent) Down(_ context.Context) error {
return m.downErr
}
func (m *mockAgent) Status() (internal.StatusType, error) {
return m.status, m.statusErr
}
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
agent := &mockAgent{status: status}
return New(agent), agent
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
h, _ := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
// Even if Up fails, flag should be reset
_ = h.HandleWakeUp(context.Background())
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
}
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
agent.upErr = errors.New("up failed")
err := h.HandleWakeUp(context.Background())
assert.ErrorIs(t, err, agent.upErr)
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
}
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
_ = h.HandleWakeUp(context.Background())
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.False(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.True(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
agent := &mockAgent{statusErr: errors.New("status error")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.statusErr)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.downErr)
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v6.33.3
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -88,58 +88,6 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
type ExposeProtocol int32
const (
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
)
// Enum value maps for ExposeProtocol.
var (
ExposeProtocol_name = map[int32]string{
0: "EXPOSE_HTTP",
1: "EXPOSE_HTTPS",
2: "EXPOSE_TCP",
3: "EXPOSE_UDP",
}
ExposeProtocol_value = map[string]int32{
"EXPOSE_HTTP": 0,
"EXPOSE_HTTPS": 1,
"EXPOSE_TCP": 2,
"EXPOSE_UDP": 3,
}
)
func (x ExposeProtocol) Enum() *ExposeProtocol {
p := new(ExposeProtocol)
*p = x
return p
}
func (x ExposeProtocol) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor()
}
func (ExposeProtocol) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1]
}
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use ExposeProtocol.Descriptor instead.
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{1}
}
// avoid collision with loglevel enum
type OSLifecycleRequest_CycleType int32
@@ -174,11 +122,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
}
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[2].Descriptor()
return file_daemon_proto_enumTypes[1].Descriptor()
}
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[2]
return &file_daemon_proto_enumTypes[1]
}
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
@@ -226,11 +174,11 @@ func (x SystemEvent_Severity) String() string {
}
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[3].Descriptor()
return file_daemon_proto_enumTypes[2].Descriptor()
}
func (SystemEvent_Severity) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[3]
return &file_daemon_proto_enumTypes[2]
}
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
@@ -281,11 +229,11 @@ func (x SystemEvent_Category) String() string {
}
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[4].Descriptor()
return file_daemon_proto_enumTypes[3].Descriptor()
}
func (SystemEvent_Category) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[4]
return &file_daemon_proto_enumTypes[3]
}
func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
@@ -5652,224 +5600,6 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
return ""
}
type ExposeServiceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceRequest) Reset() {
*x = ExposeServiceRequest{}
mi := &file_daemon_proto_msgTypes[85]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceRequest) ProtoMessage() {}
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[85]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{85}
}
func (x *ExposeServiceRequest) GetPort() uint32 {
if x != nil {
return x.Port
}
return 0
}
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
if x != nil {
return x.Protocol
}
return ExposeProtocol_EXPOSE_HTTP
}
func (x *ExposeServiceRequest) GetPin() string {
if x != nil {
return x.Pin
}
return ""
}
func (x *ExposeServiceRequest) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *ExposeServiceRequest) GetUserGroups() []string {
if x != nil {
return x.UserGroups
}
return nil
}
func (x *ExposeServiceRequest) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *ExposeServiceRequest) GetNamePrefix() string {
if x != nil {
return x.NamePrefix
}
return ""
}
type ExposeServiceEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Event:
//
// *ExposeServiceEvent_Ready
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceEvent) Reset() {
*x = ExposeServiceEvent{}
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceEvent) ProtoMessage() {}
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{86}
}
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
if x != nil {
return x.Event
}
return nil
}
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
if x != nil {
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
return x.Ready
}
}
return nil
}
type isExposeServiceEvent_Event interface {
isExposeServiceEvent_Event()
}
type ExposeServiceEvent_Ready struct {
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
}
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
type ExposeServiceReady struct {
state protoimpl.MessageState `protogen:"open.v1"`
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceReady) Reset() {
*x = ExposeServiceReady{}
mi := &file_daemon_proto_msgTypes[87]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceReady) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceReady) ProtoMessage() {}
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[87]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{87}
}
func (x *ExposeServiceReady) GetServiceName() string {
if x != nil {
return x.ServiceName
}
return ""
}
func (x *ExposeServiceReady) GetServiceUrl() string {
if x != nil {
return x.ServiceUrl
}
return ""
}
func (x *ExposeServiceReady) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5880,7 +5610,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[89]
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5892,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[89]
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -6419,25 +6149,7 @@ const file_daemon_proto_rawDesc = "" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
"\x14ExposeServiceRequest\x12\x12\n" +
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
"\vuser_groups\x18\x05 \x03(\tR\n" +
"userGroups\x12\x16\n" +
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
"\vname_prefix\x18\a \x01(\tR\n" +
"namePrefix\"Q\n" +
"\x12ExposeServiceEvent\x122\n" +
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
"\x05event\"p\n" +
"\x12ExposeServiceReady\x12!\n" +
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
"\vservice_url\x18\x02 \x01(\tR\n" +
"serviceUrl\x12\x16\n" +
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -6446,14 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a*S\n" +
"\x0eExposeProtocol\x12\x0f\n" +
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
"\n" +
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
"\n" +
"EXPOSE_UDP\x10\x032\xac\x15\n" +
"\x05TRACE\x10\a2\xdd\x14\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6492,8 +6197,7 @@ const file_daemon_proto_rawDesc = "" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -6507,222 +6211,214 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
(OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
(SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
(*EmptyRequest)(nil), // 5: daemon.EmptyRequest
(*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
(*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
(*LoginRequest)(nil), // 8: daemon.LoginRequest
(*LoginResponse)(nil), // 9: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 12: daemon.UpRequest
(*UpResponse)(nil), // 13: daemon.UpResponse
(*StatusRequest)(nil), // 14: daemon.StatusRequest
(*StatusResponse)(nil), // 15: daemon.StatusResponse
(*DownRequest)(nil), // 16: daemon.DownRequest
(*DownResponse)(nil), // 17: daemon.DownResponse
(*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
(*PeerState)(nil), // 20: daemon.PeerState
(*LocalPeerState)(nil), // 21: daemon.LocalPeerState
(*SignalState)(nil), // 22: daemon.SignalState
(*ManagementState)(nil), // 23: daemon.ManagementState
(*RelayState)(nil), // 24: daemon.RelayState
(*NSGroupState)(nil), // 25: daemon.NSGroupState
(*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
(*SSHServerState)(nil), // 27: daemon.SSHServerState
(*FullStatus)(nil), // 28: daemon.FullStatus
(*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
(*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
(*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
(*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
(*IPList)(nil), // 33: daemon.IPList
(*Network)(nil), // 34: daemon.Network
(*PortInfo)(nil), // 35: daemon.PortInfo
(*ForwardingRule)(nil), // 36: daemon.ForwardingRule
(*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
(*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
(*State)(nil), // 44: daemon.State
(*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
(*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
(*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
(*TCPFlags)(nil), // 53: daemon.TCPFlags
(*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
(*TraceStage)(nil), // 55: daemon.TraceStage
(*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
(*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
(*SystemEvent)(nil), // 58: daemon.SystemEvent
(*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
(*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
(*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
(*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
(*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
(*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
(*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
(*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
(*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
(*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
(*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
(*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
(*Profile)(nil), // 71: daemon.Profile
(*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
(*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 74: daemon.LogoutRequest
(*LogoutResponse)(nil), // 75: daemon.LogoutResponse
(*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
(*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
nil, // 93: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
nil, // 95: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse
(*LoginRequest)(nil), // 7: daemon.LoginRequest
(*LoginResponse)(nil), // 8: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 11: daemon.UpRequest
(*UpResponse)(nil), // 12: daemon.UpResponse
(*StatusRequest)(nil), // 13: daemon.StatusRequest
(*StatusResponse)(nil), // 14: daemon.StatusResponse
(*DownRequest)(nil), // 15: daemon.DownRequest
(*DownResponse)(nil), // 16: daemon.DownResponse
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse
(*PeerState)(nil), // 19: daemon.PeerState
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState
(*SignalState)(nil), // 21: daemon.SignalState
(*ManagementState)(nil), // 22: daemon.ManagementState
(*RelayState)(nil), // 23: daemon.RelayState
(*NSGroupState)(nil), // 24: daemon.NSGroupState
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo
(*SSHServerState)(nil), // 26: daemon.SSHServerState
(*FullStatus)(nil), // 27: daemon.FullStatus
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse
(*IPList)(nil), // 32: daemon.IPList
(*Network)(nil), // 33: daemon.Network
(*PortInfo)(nil), // 34: daemon.PortInfo
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse
(*State)(nil), // 43: daemon.State
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse
(*TCPFlags)(nil), // 52: daemon.TCPFlags
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest
(*TraceStage)(nil), // 54: daemon.TraceStage
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest
(*SystemEvent)(nil), // 57: daemon.SystemEvent
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse
(*Profile)(nil), // 70: daemon.Profile
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
nil, // 89: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
nil, // 91: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
72, // [72:108] is the sub-list for method output_type
36, // [36:72] is the sub-list for method input_type
36, // [36:36] is the sub-list for extension type_name
36, // [36:36] is the sub-list for extension extendee
0, // [0:36] is the sub-list for field type_name
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
69, // [69:104] is the sub-list for method output_type
34, // [34:69] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -6743,16 +6439,13 @@ func file_daemon_proto_init() {
file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
(*ExposeServiceEvent_Ready)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 5,
NumMessages: 91,
NumEnums: 4,
NumMessages: 88,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -103,9 +103,6 @@ service DaemonService {
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
}
@@ -804,32 +801,3 @@ message InstallerResultResponse {
bool success = 1;
string errorMsg = 2;
}
enum ExposeProtocol {
EXPOSE_HTTP = 0;
EXPOSE_HTTPS = 1;
EXPOSE_TCP = 2;
EXPOSE_UDP = 3;
}
message ExposeServiceRequest {
uint32 port = 1;
ExposeProtocol protocol = 2;
string pin = 3;
string password = 4;
repeated string user_groups = 5;
string domain = 6;
string name_prefix = 7;
}
message ExposeServiceEvent {
oneof event {
ExposeServiceReady ready = 1;
}
}
message ExposeServiceReady {
string service_name = 1;
string service_url = 2;
string domain = 3;
}

View File

@@ -76,8 +76,6 @@ type DaemonServiceClient interface {
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
}
type daemonServiceClient struct {
@@ -426,38 +424,6 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
return out, nil
}
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
if err != nil {
return nil, err
}
x := &daemonServiceExposeServiceClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type DaemonService_ExposeServiceClient interface {
Recv() (*ExposeServiceEvent, error)
grpc.ClientStream
}
type daemonServiceExposeServiceClient struct {
grpc.ClientStream
}
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
m := new(ExposeServiceEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -520,8 +486,6 @@ type DaemonServiceServer interface {
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -634,9 +598,6 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
}
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1283,27 +1244,6 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(ExposeServiceRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
}
type DaemonService_ExposeServiceServer interface {
Send(*ExposeServiceEvent) error
grpc.ServerStream
}
type daemonServiceExposeServiceServer struct {
grpc.ServerStream
}
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
return x.ServerStream.SendMsg(m)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1454,11 +1394,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
Handler: _DaemonService_SubscribeEvents_Handler,
ServerStreams: true,
},
{
StreamName: "ExposeService",
Handler: _DaemonService_ExposeService_Handler,
ServerStreams: true,
},
},
Metadata: "daemon.proto",
}

View File

@@ -0,0 +1,77 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -0,0 +1,219 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -21,9 +21,7 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -87,7 +85,8 @@ type Server struct {
profilesDisabled bool
updateSettingsDisabled bool
sleepHandler *sleephandler.SleepHandler
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache
}
@@ -101,7 +100,7 @@ type oauthAuthFlow struct {
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
s := &Server{
return &Server{
rootCtx: ctx,
logFile: logFile,
persistSyncResponse: true,
@@ -111,10 +110,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
return s
}
func (s *Server) Start() error {
@@ -641,6 +636,8 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return s.waitForUp(callerCtx)
}
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err)
}
@@ -652,12 +649,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
// not in the progress or already successfully established connection.
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return nil, err
}
if status != internal.StatusIdle {
s.mutex.Unlock()
return nil, fmt.Errorf("up already in progress: current status %s", status)
}
@@ -674,20 +669,17 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.actCancel = cancel
if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first")
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
@@ -695,7 +687,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
activeProf, err = s.profileManager.GetActiveProfileState()
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
@@ -704,7 +695,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
config, _, err := s.getConfig(activeProf)
if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
}
@@ -723,7 +713,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
@@ -1323,60 +1312,6 @@ func (s *Server) WaitJWTToken(
}, nil
}
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.clientRunning {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
}
ctx := srv.Context()
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
defer exposeCancel()
mgmReq := expose.NewRequest(req)
result, err := mgr.Expose(exposeCtx, *mgmReq)
if err != nil {
return err
}
if err := srv.Send(&proto.ExposeServiceEvent{
Event: &proto.ExposeServiceEvent_Ready{
Ready: &proto.ExposeServiceReady{
ServiceName: result.ServiceName,
ServiceUrl: result.ServiceURL,
Domain: result.Domain,
},
},
}); err != nil {
return err
}
err = mgr.KeepAlive(ctx, result.Domain)
if err != nil {
return err
}
return nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false

View File

@@ -1,46 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
type serverAgent struct {
s *Server
}
func (a *serverAgent) Up(ctx context.Context) error {
_, err := a.s.Up(ctx, &proto.UpRequest{})
return err
}
func (a *serverAgent) Down(ctx context.Context) error {
_, err := a.s.Down(ctx, &proto.DownRequest{})
return err
}
func (a *serverAgent) Status() (internal.StatusType, error) {
return internal.CtxGetState(a.s.rootCtx).Status()
}
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
case proto.OSLifecycleRequest_SLEEP:
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -19,7 +19,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
@@ -269,7 +268,7 @@ func getDefaultDaemonAddr() string {
if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows
}
return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
return DefaultDaemonAddr
}
// DialOptions contains options for SSH connections

View File

@@ -46,10 +46,8 @@ const (
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
// Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"os"
"path"
"path/filepath"
"strings"
"time"
@@ -71,8 +70,6 @@ type ServerConfig struct {
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ActivityStore StoreConfig `yaml:"activityStore"`
AuthStore StoreConfig `yaml:"authStore"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
@@ -173,8 +170,7 @@ type RelaysConfig struct {
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
@@ -536,74 +532,6 @@ func stripSignalProtocol(uri string) string {
return uri
}
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
var ttl time.Duration
if relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
}
}
return &nbconfig.Relay{
Addresses: relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: relays.Secret,
}, nil
}
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
// authStore overrides auth.storage when set.
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
authStorageType := mgmt.Auth.Storage.Type
authStorageDSN := c.Server.AuthStore.DSN
if c.Server.AuthStore.Engine != "" {
authStorageType = c.Server.AuthStore.Engine
}
if authStorageType == "" {
authStorageType = "sqlite3"
}
authStorageFile := ""
if authStorageType == "postgres" {
if authStorageDSN == "" {
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
}
}
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
File: authStorageFile,
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
cfg.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password,
}
}
return cfg, nil
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
@@ -622,11 +550,19 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
relay, err := buildRelayConfig(mgmt.Relays)
if err != nil {
return nil, err
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
relayConfig = relay
}
// Build signal config
@@ -662,9 +598,31 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
if err != nil {
return nil, err
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP

View File

@@ -140,23 +140,6 @@ func initializeConfig() error {
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
log.Infof("Starting combined NetBird server")
logConfig(config)
@@ -685,11 +668,8 @@ func logEnvVars() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
} else if strings.Contains(keyLower, "dsn") {
value = maskDSNPassword(value)
}
log.Infof(" %s=%s", key, value)
found = true

View File

@@ -42,9 +42,6 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)

View File

@@ -103,19 +103,6 @@ server:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
# Activity events store configuration (optional, defaults to sqlite in dataDir)
# activityStore:
# engine: "sqlite" # sqlite or postgres
# dsn: "" # Connection string for postgres
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
# authStore:
# engine: "sqlite3" # sqlite3 or postgres
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
# Reverse proxy settings (optional)
# reverseProxy:

View File

@@ -5,10 +5,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
"net/url"
"os"
"strconv"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
@@ -198,175 +195,11 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
}
return (&sql.SQLite3{File: file}).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
}
pg, err := parsePostgresDSN(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
}
return pg.Open(logger)
default:
return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
}
}
// parsePostgresDSN parses a DSN into a sql.Postgres config.
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
params, err = parsePostgresURI(dsn)
} else {
params, err = parsePostgresKeyValue(dsn)
}
if err != nil {
return nil, err
}
host := params["host"]
if host == "" {
host = "localhost"
}
var port uint16 = 5432
if p, ok := params["port"]; ok && p != "" {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %w", p, err)
}
if v == 0 {
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
}
port = uint16(v)
}
dbname := params["dbname"]
if dbname == "" {
return nil, fmt.Errorf("dbname is required in DSN")
}
pg := &sql.Postgres{
NetworkDB: sql.NetworkDB{
Host: host,
Port: port,
Database: dbname,
User: params["user"],
Password: params["password"],
},
}
if sslMode := params["sslmode"]; sslMode != "" {
switch sslMode {
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
pg.SSL.Mode = sslMode
default:
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
}
}
return pg, nil
}
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
func parsePostgresURI(dsn string) (map[string]string, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres URI: %w", err)
}
params := make(map[string]string)
if u.User != nil {
params["user"] = u.User.Username()
if p, ok := u.User.Password(); ok {
params["password"] = p
}
}
if u.Hostname() != "" {
params["host"] = u.Hostname()
}
if u.Port() != "" {
params["port"] = u.Port()
}
dbname := strings.TrimPrefix(u.Path, "/")
if dbname != "" {
params["dbname"] = dbname
}
for k, v := range u.Query() {
if len(v) > 0 {
params[k] = v[0]
}
}
return params, nil
}
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
// (e.g., password='my pass' host=localhost).
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
params := make(map[string]string)
s := strings.TrimSpace(dsn)
for s != "" {
eqIdx := strings.IndexByte(s, '=')
if eqIdx < 0 {
break
}
key := strings.TrimSpace(s[:eqIdx])
value, rest, err := parseDSNValue(s[eqIdx+1:])
if err != nil {
return nil, fmt.Errorf("%w for key %q", err, key)
}
params[key] = value
s = strings.TrimSpace(rest)
}
return params, nil
}
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
// It returns the parsed value and the remaining unparsed string.
func parseDSNValue(s string) (value, rest string, err error) {
if len(s) > 0 && s[0] == '\'' {
return parseQuotedDSNValue(s[1:])
}
// Unquoted value: read until whitespace.
idx := strings.IndexAny(s, " \t\n")
if idx < 0 {
return s, "", nil
}
return s[:idx], s[idx:], nil
}
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
// Libpq uses ” to represent a literal single quote inside quoted values.
func parseQuotedDSNValue(s string) (value, rest string, err error) {
var buf strings.Builder
for len(s) > 0 {
if s[0] == '\'' {
if len(s) > 1 && s[1] == '\'' {
buf.WriteByte('\'')
s = s[2:]
continue
}
return buf.String(), s[1:], nil
}
buf.WriteByte(s[0])
s = s[1:]
}
return "", "", fmt.Errorf("unterminated quoted value")
}
// Validate validates the configuration
func (c *YAMLConfig) Validate() error {
if c.Issuer == "" {

View File

@@ -63,8 +63,6 @@ type Controller struct {
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -87,12 +85,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
newNetworkMapBuilder = false
}
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
@@ -116,8 +108,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -240,12 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -368,12 +355,9 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -495,12 +479,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -875,12 +854,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -9,5 +9,4 @@ type Manager interface {
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
GetClusterDomains() []string
}

View File

@@ -228,18 +228,6 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}
}
// GetClusterDomains returns a list of proxy cluster domains.
func (m Manager) GetClusterDomains() []string {
if m.proxyManager == nil {
return nil
}
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.

View File

@@ -1,12 +1,8 @@
package proxy
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// Manager defines the interface for proxy operations
@@ -17,20 +13,3 @@ type Manager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -5,7 +5,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
@@ -20,21 +19,14 @@ type store interface {
// Manager handles all proxy operations
type Manager struct {
store store
metrics *metrics
store store
}
// NewManager creates a new proxy Manager
func NewManager(store store, meter metric.Meter) (*Manager, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
func NewManager(store store) Manager {
return Manager{
store: store,
}
return &Manager{
store: store,
metrics: m,
}, nil
}
// Connect registers a new proxy connection in the database
@@ -91,7 +83,6 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
m.metrics.IncrementProxyHeartbeatCount()
return nil
}

View File

@@ -1,74 +0,0 @@
package manager
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type metrics struct {
proxyConnectionCount metric.Int64UpDownCounter
serviceUpdateSendCount metric.Int64Counter
proxyHeartbeatCount metric.Int64Counter
}
func newMetrics(meter metric.Meter) (*metrics, error) {
proxyConnectionCount, err := meter.Int64UpDownCounter(
"management_proxy_connection_count",
metric.WithDescription("Number of active proxy connections"),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, err
}
serviceUpdateSendCount, err := meter.Int64Counter(
"management_proxy_service_update_send_count",
metric.WithDescription("Total number of service updates sent to proxies"),
metric.WithUnit("{update}"),
)
if err != nil {
return nil, err
}
proxyHeartbeatCount, err := meter.Int64Counter(
"management_proxy_heartbeat_count",
metric.WithDescription("Total number of proxy heartbeats received"),
metric.WithUnit("{heartbeat}"),
)
if err != nil {
return nil, err
}
return &metrics{
proxyConnectionCount: proxyConnectionCount,
serviceUpdateSendCount: serviceUpdateSendCount,
proxyHeartbeatCount: proxyHeartbeatCount,
}, nil
}
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), -1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
m.serviceUpdateSendCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementProxyHeartbeatCount() {
m.proxyHeartbeatCount.Add(context.Background(), 1)
}

View File

@@ -1,199 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./manager.go
// Package proxy is a generated GoMock package.
package proxy
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CleanupStale mocks base method.
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStale indicates an expected call of CleanupStale.
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -4,6 +4,8 @@ package service
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Manager interface {
@@ -12,7 +14,6 @@ type Manager interface {
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
@@ -21,8 +22,13 @@ type Manager interface {
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StartExposeReaper(ctx context.Context)
}
// ProxyController is responsible for managing proxy clusters and routing service updates.
type ProxyController interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
@@ -49,35 +50,6 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// CreateServiceFromPeer mocks base method.
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
ret0, _ := ret[0].(*ExposeServiceResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAllServices indicates an expected call of DeleteAllServices.
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -210,20 +182,6 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
@@ -252,32 +210,6 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// StartExposeReaper mocks base method.
func (m *MockManager) StartExposeReaper(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartExposeReaper", ctx)
}
// StartExposeReaper indicates an expected call of StartExposeReaper.
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
@@ -292,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}
// MockProxyController is a mock of ProxyController interface.
type MockProxyController struct {
ctrl *gomock.Controller
recorder *MockProxyControllerMockRecorder
}
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyControllerMockRecorder struct {
mock *MockProxyController
}
// NewMockProxyController creates a new mock instance.
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -5,75 +5,61 @@ import (
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCProxyController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
metrics *metrics
}
// NewGRPCController creates a new GRPCController.
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &GRPCController{
// NewGRPCProxyController creates a new GRPCProxyController.
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
return &GRPCProxyController{
proxyGRPCServer: proxyGRPCServer,
metrics: m,
}, nil
}
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
c.metrics.IncrementProxyConnectionCount(clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
c.metrics.DecrementProxyConnectionCount(clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil

View File

@@ -1,65 +0,0 @@
package manager
import (
"context"
"math/rand/v2"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
exposeReapBatch = 100
)
type exposeReaper struct {
manager *Manager
}
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
go func() {
// start with a random delay
rn := rand.IntN(10)
time.Sleep(time.Duration(rn) * time.Second)
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
r.reapExpiredExposes(ctx)
}
}
}()
}
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
if err != nil {
log.Errorf("failed to get expired ephemeral services: %v", err)
return
}
for _, svc := range expired {
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
if err == nil {
continue
}
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
log.Debugf("service %s was already deleted by another instance", svc.Domain)
} else {
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
}
}
}

View File

@@ -1,208 +0,0 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
func TestReapExpiredExposes(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the service by backdating meta_last_renewed_at
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
require.NoError(t, err, "active service should remain")
}
func TestReapAlreadyDeletedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
mgr.exposeReaper.reapExpiredExposes(ctx)
}
func TestConcurrentReapAndRenew(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Expire all services
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
for _, svc := range services {
if svc.Source == rpservice.SourceEphemeral {
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
}
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
mgr.exposeReaper.reapExpiredExposes(ctx)
}()
go func() {
defer wg.Done()
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
}()
wg.Wait()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all expired services should be reaped")
}
func TestRenewEphemeralService(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
}
func TestCountAndExistsEphemeralServices(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
})
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
assert.True(t, exists, "service should exist")
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
require.NoError(t, err)
assert.False(t, exists, "non-existent service should not exist")
}
func TestMaxExposesPerPeerEnforced(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
}
func TestReapSkipsRenewedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
})
require.NoError(t, err)
// Expire the service
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)
svc.Meta.LastRenewedAt = &expired
err = s.UpdateService(context.Background(), svc)
require.NoError(t, err)
}

View File

@@ -3,15 +3,10 @@ package manager
import (
"context"
"fmt"
"math/rand/v2"
"slices"
"time"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
@@ -28,34 +23,25 @@ const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
GetClusterDomains() []string
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
proxyController service.ProxyController
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
return &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr
}
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -215,52 +201,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, servi
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
@@ -444,10 +384,6 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
return err
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
@@ -467,47 +403,6 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
return nil
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
}
return nil
})
if err != nil {
return err
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, svc := range services {
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
@@ -517,8 +412,7 @@ func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, service
return fmt.Errorf("failed to get service: %w", err)
}
now := time.Now()
service.Meta.CertificateIssuedAt = &now
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
@@ -642,287 +536,3 @@ func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string,
return target.ServiceID, nil
}
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return status.Errorf(status.Internal, "get account settings: %v", err)
}
if !settings.PeerExposeEnabled {
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
}
if len(settings.PeerExposeGroups) == 0 {
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
}
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
return status.Errorf(status.Internal, "get peer groups: %v", err)
}
for _, pg := range peerGroupIDs {
if slices.Contains(settings.PeerExposeGroups, pg) {
return nil
}
}
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
return nil, err
}
serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
svc := req.ToService(accountID, peerID, serviceName)
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
}
svc.Domain = domain
}
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
}
svc.Auth.BearerAuth.DistributionGroups = groupIDs
}
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
svc.SourcePeer = peerID
now := time.Now()
svc.Meta.LastRenewedAt = &now
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
return nil, err
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
}, nil
}
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided")
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}
groupIDs = append(groupIDs, g.ID)
}
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
// that it is still expired under a row lock. This prevents deleting a service
// that was renewed between the batch query and this delete, and ensures only one
// management instance processes the deletion
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
var svc *service.Service
deleted := false
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
}
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
return nil
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
deleted = true
return nil
})
if err != nil {
return err
}
if !deleted {
return nil
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
if peer == nil {
return meta
}
meta["peer_name"] = peer.Name
if peer.IP != nil {
meta["peer_ip"] = peer.IP.String()
}
return meta
}

View File

@@ -3,28 +3,15 @@ package manager
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -368,8 +355,8 @@ func TestPreserveServiceMetadata(t *testing.T) {
mgr := &Manager{}
existing := &rpservice.Service{
Meta: rpservice.Meta{
CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(),
Meta: rpservice.ServiceMeta{
CertificateIssuedAt: time.Now(),
Status: "active",
},
SessionPrivateKey: "private-key",
@@ -386,807 +373,3 @@ func TestPreserveServiceMetadata(t *testing.T) {
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
}
func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
ownerPeerID := "peer-owner"
otherPeerID := "peer-other"
serviceID := "service-123"
testPeer := &nbpeer.Peer{
ID: ownerPeerID,
Name: "test-peer",
IP: net.ParseIP("100.64.0.1"),
}
newEphemeralService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "test-service",
Domain: "test.example.com",
Source: rpservice.SourceEphemeral,
SourcePeer: ownerPeerID,
}
}
newPermanentService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "api-service",
Domain: "api.example.com",
Source: rpservice.SourcePermanent,
}
}
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
return srv
}
t.Run("owner peer can delete own service", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var storedActivity activity.Activity
mockStore := store.NewMockStore(ctrl)
mockAccountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
storedActivity = activityID.(activity.Activity)
},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
}
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
return fn(txMock)
})
mockStore.EXPECT().
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
require.NoError(t, err)
assert.Equal(t, activity.PeerServiceUnexposed, storedActivity, "should store unexposed activity")
})
t.Run("different peer cannot delete service", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
return fn(txMock)
})
mgr := &Manager{
store: mockStore,
}
err := mgr.deletePeerService(ctx, accountID, otherPeerID, serviceID, activity.PeerServiceUnexposed)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok, "should be a status error")
assert.Equal(t, status.PermissionDenied, sErr.Type(), "should be permission denied")
assert.Contains(t, err.Error(), "another peer")
})
t.Run("cannot delete API-created service", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newPermanentService(), nil)
return fn(txMock)
})
mgr := &Manager{
store: mockStore,
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok, "should be a status error")
assert.Equal(t, status.PermissionDenied, sErr.Type(), "should be permission denied")
assert.Contains(t, err.Error(), "API-created")
})
t.Run("expire uses correct activity code", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var storedActivity activity.Activity
mockStore := store.NewMockStore(ctrl)
mockAccountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
storedActivity = activityID.(activity.Activity)
},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
}
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
return fn(txMock)
})
mockStore.EXPECT().
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired)
require.NoError(t, err)
assert.Equal(t, activity.PeerServiceExposeExpired, storedActivity, "should store expired activity")
})
t.Run("event meta includes peer info", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var storedMeta map[string]any
mockStore := store.NewMockStore(ctrl)
mockAccountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
storedMeta = meta
},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
}
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
return fn(txMock)
})
mockStore.EXPECT().
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
require.NoError(t, err)
require.NotNil(t, storedMeta)
assert.Equal(t, "test-peer", storedMeta["peer_name"], "meta should contain peer name")
assert.Equal(t, "100.64.0.1", storedMeta["peer_ip"], "meta should contain peer IP")
assert.Equal(t, "test-service", storedMeta["name"], "meta should contain service name")
assert.Equal(t, "test.example.com", storedMeta["domain"], "meta should contain service domain")
})
}
// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list.
type testClusterDeriver struct {
domains []string
}
func (d *testClusterDeriver) DeriveClusterFromDomain(_ context.Context, _, domain string) (string, error) {
return "test-cluster", nil
}
func (d *testClusterDeriver) GetClusterDomains() []string {
return d.domains
}
const (
testAccountID = "test-account"
testPeerID = "test-peer-1"
testGroupID = "test-group-1"
testUserID = "test-user"
)
// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests.
func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
CreatedBy: testUserID,
Settings: &types.Settings{
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
AccountID: testAccountID,
Key: "test-key",
DNSLabel: "test-peer",
Name: "test-peer",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
},
},
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
AccountID: testAccountID,
Name: "Expose Group",
},
},
})
require.NoError(t, err)
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permsMgr,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"},
},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr, testStore
}
func Test_validateExposePermission(t *testing.T) {
ctx := context.Background()
t.Run("allowed when peer is in expose group", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.NoError(t, err)
})
t.Run("denied when peer is not in expose group", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// Add a peer that is NOT in the expose group
otherPeerID := "other-peer"
err := testStore.AddPeerToAccount(ctx, &nbpeer.Peer{
ID: otherPeerID,
AccountID: testAccountID,
Key: "other-key",
DNSLabel: "other-peer",
Name: "other-peer",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{Hostname: "other-peer"},
})
require.NoError(t, err)
err = mgr.validateExposePermission(ctx, testAccountID, otherPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not in an allowed expose group")
})
t.Run("denied when expose is disabled", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// Disable peer expose
s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
s.PeerExposeEnabled = false
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
t.Run("disallowed when no groups configured", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// Enable expose with empty groups — no groups configured means no peer is allowed
s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
s.PeerExposeGroups = []string{}
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.Error(t, err)
})
t.Run("error when store returns error", func(t *testing.T) {
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &Manager{store: mockStore}
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings")
})
}
func TestCreateServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.NotEmpty(t, resp.ServiceName, "service name should be generated")
assert.Contains(t, resp.Domain, "test.netbird.io", "domain should use cluster domain")
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
})
t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.Contains(t, resp.Domain, "example.com", "should use the provided domain")
})
t.Run("validates expose permission internally", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// Disable peer expose
s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
s.PeerExposeEnabled = false
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "port")
})
}
func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct {
name string
req rpservice.ExposeServiceRequest
wantErr string
}{
{
name: "valid http request",
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
},
{
name: "port zero rejected",
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
},
{
name: "invalid pin format",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
}
})
}
t.Run("nil receiver", func(t *testing.T) {
var req *rpservice.ExposeServiceRequest
err := req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil")
})
}
func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
ctx := context.Background()
t.Run("deletes service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// First create a service
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
require.NoError(t, err)
// Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
require.NoError(t, err)
})
}
func TestStopServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
}
func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
}
func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(3), count, "all ephemeral services should exist")
err = mgr.DeleteAllServices(ctx, testAccountID, testUserID)
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices")
}
func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
})
}
func TestGetGroupIDsFromNames(t *testing.T) {
ctx := context.Background()
t.Run("resolves group names to IDs", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ids, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{"Expose Group"})
require.NoError(t, err)
require.Len(t, ids, 1, "should return exactly one group ID")
assert.Equal(t, testGroupID, ids[0])
})
t.Run("returns error for unknown group", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
_, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{"nonexistent"})
require.Error(t, err)
})
t.Run("returns error for empty group list", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
_, err := mgr.getGroupIDsFromNames(ctx, testAccountID, []string{})
require.Error(t, err)
assert.Contains(t, err.Error(), "no group names provided")
})
}
func TestDeleteService_DeletesTargets(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
userID := "test-user"
sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyController: proxyController,
}
service := &rpservice.Service{
ID: "service-1",
AccountID: accountID,
Domain: "test.example.com",
ProxyCluster: "cluster1",
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"},
},
}
err = sqlStore.CreateService(ctx, service)
require.NoError(t, err)
retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().
UpdateAccountPeers(ctx, accountID)
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
require.NoError(t, err)
_, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.Error(t, err)
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}

View File

@@ -1,20 +1,16 @@
package service
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"regexp"
"strconv"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
@@ -44,9 +40,6 @@ const (
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
)
type Target struct {
@@ -112,11 +105,17 @@ func (a *AuthConfig) ClearSecrets() {
}
}
type Meta struct {
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt *time.Time
CertificateIssuedAt time.Time
Status string
LastRenewedAt *time.Time
}
type Service struct {
@@ -129,12 +128,10 @@ type Service struct {
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
@@ -159,7 +156,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target,
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = Meta{
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
@@ -210,8 +207,8 @@ func (s *Service) ToAPIResponse() *api.Service {
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if s.Meta.CertificateIssuedAt != nil {
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.Service{
@@ -233,7 +230,7 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
@@ -406,11 +403,7 @@ func (s *Service) Validate() error {
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}
func (s *Service) isAuthEnabled() bool {
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (s *Service) Copy() *Service {
@@ -434,8 +427,6 @@ func (s *Service) Copy() *Service {
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
Source: s.Source,
SourcePeer: s.SourcePeer,
}
}
@@ -470,140 +461,3 @@ func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
return nil
}
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
NamePrefix string
Port int
Protocol string
Domain string
Pin string
Password string
UserGroups []string
}
// Validate checks all fields of the expose request.
func (r *ExposeServiceRequest) Validate() error {
if r == nil {
return errors.New("request cannot be nil")
}
if r.Port < 1 || r.Port > 65535 {
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
}
if r.Protocol != "http" && r.Protocol != "https" {
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
}
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
return errors.New("invalid pin: must be exactly 6 digits")
}
for _, g := range r.UserGroups {
if g == "" {
return errors.New("user group name cannot be empty")
}
}
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
}
return nil
}
// ToService builds a Service from the expose request.
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
service := &Service{
AccountID: accountID,
Name: serviceName,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: r.Protocol,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
},
}
if r.Domain != "" {
service.Domain = serviceName + "." + r.Domain
}
if r.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: r.Pin,
}
}
if r.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: r.Password,
}
}
if len(r.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: r.UserGroups,
}
}
return service
}
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
ServiceName string
ServiceURL string
Domain string
}
// GenerateExposeName generates a random service name for peer-exposed services.
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
func GenerateExposeName(prefix string) (string, error) {
if prefix != "" && !validNamePrefix.MatchString(prefix) {
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
}
suffixLen := 12
if prefix != "" {
suffixLen = 4
}
suffix, err := randomAlphanumeric(suffixLen)
if err != nil {
return "", fmt.Errorf("generate random name: %w", err)
}
if prefix == "" {
return suffix, nil
}
return prefix + "-" + suffix, nil
}
func randomAlphanumeric(n int) (string, error) {
result := make([]byte, n)
charsetLen := big.NewInt(int64(len(alphanumCharset)))
for i := range result {
idx, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
return "", err
}
result[i] = alphanumCharset[idx.Int64()]
}
return string(result), nil
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -110,7 +109,7 @@ func TestIsDefaultPort(t *testing.T) {
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := proxy.OIDCValidationConfig{}
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
@@ -203,7 +202,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
@@ -220,7 +219,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) {
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
@@ -404,146 +403,3 @@ func TestAuthConfig_ClearSecrets(t *testing.T) {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}
func TestGenerateExposeName(t *testing.T) {
t.Run("no prefix generates 12-char name", func(t *testing.T) {
name, err := GenerateExposeName("")
require.NoError(t, err)
assert.Len(t, name, 12)
assert.Regexp(t, `^[a-z0-9]+$`, name)
})
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
name, err := GenerateExposeName("myapp")
require.NoError(t, err)
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
suffix := strings.TrimPrefix(name, "myapp-")
assert.Len(t, suffix, 4, "suffix should be 4 chars")
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
})
t.Run("unique names", func(t *testing.T) {
names := make(map[string]bool)
for i := 0; i < 50; i++ {
name, err := GenerateExposeName("")
require.NoError(t, err)
names[name] = true
}
assert.Greater(t, len(names), 45, "should generate mostly unique names")
})
t.Run("valid prefixes", func(t *testing.T) {
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
for _, prefix := range validPrefixes {
name, err := GenerateExposeName(prefix)
assert.NoError(t, err, "prefix %q should be valid", prefix)
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
}
})
t.Run("invalid prefixes", func(t *testing.T) {
invalidPrefixes := []string{
"-starts-with-dash",
"ends-with-dash-",
"has.dots",
"HAS-UPPER",
"has spaces",
"has/slash",
"a--",
}
for _, prefix := range invalidPrefixes {
_, err := GenerateExposeName(prefix)
assert.Error(t, err, "prefix %q should be invalid", prefix)
assert.Contains(t, err.Error(), "invalid name prefix")
}
})
}
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
assert.Equal(t, "account-1", service.AccountID)
assert.Equal(t, "mysvc", service.Name)
assert.True(t, service.Enabled)
assert.Empty(t, service.Domain, "domain should be empty when not specified")
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
assert.True(t, target.Enabled)
assert.Equal(t, "account-1", target.AccountID)
})
t.Run("with custom domain", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 3000,
Domain: "example.com",
}
service := req.ToService("acc", "peer", "web")
assert.Equal(t, "web.example.com", service.Domain)
})
t.Run("with PIN auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Pin: "1234",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PinAuth)
assert.True(t, service.Auth.PinAuth.Enabled)
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
assert.Nil(t, service.Auth.PasswordAuth)
assert.Nil(t, service.Auth.BearerAuth)
})
t.Run("with password auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Password: "secret",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PasswordAuth)
assert.True(t, service.Auth.PasswordAuth.Enabled)
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
})
t.Run("with user groups (bearer auth)", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
UserGroups: []string{"admins", "devs"},
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.BearerAuth)
assert.True(t, service.Auth.BearerAuth.Enabled)
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
})
t.Run("with all auth types", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 443,
Domain: "myco.com",
Pin: "9999",
Password: "pass",
UserGroups: []string{"ops"},
}
service := req.ToService("acc", "peer", "full")
assert.Equal(t, "full.myco.com", service.Domain)
require.NotNil(t, service.Auth.PinAuth)
require.NotNil(t, service.Auth.PasswordAuth)
require.NotNil(t, service.Auth.BearerAuth)
})
}

View File

@@ -152,11 +152,6 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
serviceMgr := s.ServiceManager()
srv.SetReverseProxyManager(serviceMgr)
if serviceMgr != nil {
serviceMgr.StartExposeReaper(context.Background())
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
@@ -170,7 +165,7 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
})
return proxyService

View File

@@ -6,8 +6,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -108,13 +108,9 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
})
}
func (s *BaseServer) ServiceProxyController() proxy.Controller {
return Create(s, func() proxy.Controller {
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create service proxy controller: %v", err)
}
return controller
func (s *BaseServer) ServiceProxyController() service.ProxyController {
return Create(s, func() service.ProxyController {
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
})
}

View File

@@ -200,11 +200,7 @@ func (s *BaseServer) ServiceManager() service.Manager {
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create proxy manager: %v", err)
}
return manager
return proxymanager.NewManager(s.Store())
})
}

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetServiceManager).
// registered during GRPCServer() construction (e.g., SetProxyManager).
s.GRPCServer()
for _, fn := range s.afterInit {

View File

@@ -1,202 +0,0 @@
package grpc
import (
"fmt"
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache cache.DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &cache.DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}
result := buildJWTConfig(config, nil)
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}

View File

@@ -1,192 +0,0 @@
package grpc
import (
"context"
pb "github.com/golang/protobuf/proto" // nolint
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
// CreateExpose handles a peer request to create a new expose service.
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
exposeReq := &proto.ExposeServiceRequest{}
peerKey, err := s.parseRequest(ctx, req, exposeReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
})
if err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
})
}
// RenewExpose extends the TTL of an active expose session.
func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
renewReq := &proto.RenewExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, renewReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
}
// StopExpose terminates an active expose session.
func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
stopReq := &proto.StopExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, stopReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
}
func mapExposeError(ctx context.Context, err error) error {
s, ok := internalStatus.FromError(err)
if !ok {
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
switch s.Type() {
case internalStatus.InvalidArgument:
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, "%s", s.Message)
case internalStatus.AlreadyExists:
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
default:
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
}
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg)
if err != nil {
return nil, status.Errorf(codes.Internal, "encrypt response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) {
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return "", nil, status.Errorf(codes.Internal, "lookup account for peer")
}
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return accountID, peer, nil
}
func (s *Server) getReverseProxyManager() rpservice.Manager {
s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock()
return s.reverseProxyManager
}
// SetReverseProxyManager sets the reverse proxy manager on the server.
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyMu.Lock()
defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
switch p {
case proto.ExposeProtocol_EXPOSE_HTTP:
return "http"
case proto.ExposeProtocol_EXPOSE_HTTPS:
return "https"
default:
return "http"
}
}

View File

@@ -1,276 +0,0 @@
package grpc
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
// nolint
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -59,6 +59,9 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
// Manager for access logs
accessLogManager accesslogs.Manager
@@ -66,7 +69,7 @@ type ProxyServiceServer struct {
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
proxyController rpservice.ProxyController
// Manager for proxy connections
proxyManager proxy.Manager
@@ -102,7 +105,7 @@ type proxyConnection struct {
proxyID string
address string
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
sendChan chan *proto.ProxyMapping
ctx context.Context
cancel context.CancelFunc
}
@@ -111,6 +114,7 @@ type proxyConnection struct {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
@@ -165,11 +169,11 @@ func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
s.proxyController = proxyController
}
@@ -199,7 +203,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
sendChan: make(chan *proto.ProxyMapping, 100),
ctx: connCtx,
cancel: cancel,
}
@@ -345,7 +349,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
for {
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(msg); err != nil {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
errChan <- err
return
}
@@ -396,7 +400,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
@@ -406,7 +410,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
}
@@ -455,12 +459,8 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
updateResponse := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{update},
}
if clusterAddr == "" {
s.SendServiceUpdate(updateResponse)
s.SendServiceUpdate(update)
return
}
@@ -479,7 +479,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(updateResponse, proxyID)
msg := s.perProxyMessage(update, proxyID)
if msg == nil {
continue
}
@@ -494,31 +494,23 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// create/update operations. For delete operations the original message is
// returned unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
resp = append(resp, mapping)
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
}
msg := shallowCloneMapping(mapping)
msg.AuthToken = token
resp = append(resp, msg)
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: resp,
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
}
msg := shallowCloneMapping(update)
msg.AuthToken = token
return msg
}
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
@@ -814,8 +806,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{
func (s *ProxyServiceServer) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return rpservice.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation,

View File

@@ -1,98 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -1,399 +0,0 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*service.Service
err error
}
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*service.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &service.ExposeServiceResponse{}, nil
}
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*service.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*service.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*service.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -1,295 +0,0 @@
package grpc
import (
"context"
"crypto/rand"
"encoding/base64"
"strings"
"testing"
"time"
"sync"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/management/proto"
)
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
}
func newTestProxyController() *testProxyController {
return &testProxyController{
clusterProxies: make(map[string]map[string]struct{}),
}
}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
}
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clusterProxies[clusterAddr]; !ok {
c.clusterProxies[clusterAddr] = make(map[string]struct{})
}
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
delete(proxies, proxyID)
}
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
proxies, ok := c.clusterProxies[clusterAddr]
if !ok {
return nil
}
result := make([]string, 0, len(proxies))
for id := range proxies {
result = append(result, id)
}
return result
}
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
return ch
}
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
resp := drainChannel(ch)
require.NotNil(t, resp, "proxy %d should receive a message", i)
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
msg := resp.Mapping[0]
assert.Equal(t, mapping.Domain, msg.Domain)
assert.Equal(t, mapping.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens
assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken)
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdate(update)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := resp1.Mapping[0]
msg2 := resp2.Mapping[0]
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job"
@@ -81,9 +80,6 @@ type Server struct {
syncSem atomic.Int32
syncLimEnabled bool
syncLim int32
reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex
}
// NewServer creates a new Management server

View File

@@ -1,108 +0,0 @@
package grpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
key, err := mgmtServer.secretsManager.GetWGKey()
require.NoError(t, err, "should be able to get server key")
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}

View File

@@ -1,250 +0,0 @@
package grpc
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"hash"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &config.Host{
Proto: config.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
}
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
if turnCredentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if turnCredentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
relayCredentials, err := tested.GenerateRelayToken()
require.NoError(t, err)
if relayCredentials.Payload == "" {
t.Errorf("expected generated relay payload not to be empty, got empty")
}
if relayCredentials.Signature == "" {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tested.SetupRefresh(ctx, "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
var updates []*network_map.UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
select {
case update := <-updateChannel:
updates = append(updates, update)
case <-timeout:
break loop
}
if len(updates) >= 2 {
break loop
}
}
if len(updates) < 2 {
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
}
var turnUpdates, relayUpdates int
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
for _, update := range updates {
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
turnUpdates++
if turnUpdates == 1 {
firstTurnUpdate = turns[0]
} else {
secondTurnUpdate = turns[0]
}
}
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
// avoid updating on turn updates since they also send relay credentials
if update.Update.GetNetbirdConfig().GetTurns() == nil {
relayUpdates++
if relayUpdates == 1 {
firstRelayUpdate = relay
} else {
secondRelayUpdate = relay
}
}
}
}
if turnUpdates < 1 {
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
}
if relayUpdates < 1 {
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
}
if firstTurnUpdate != nil && secondTurnUpdate != nil {
if firstTurnUpdate.Password == secondTurnUpdate.Password {
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
}
}
if firstRelayUpdate != nil && secondRelayUpdate != nil {
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
}
}
}
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in relay cancel map, got not present")
}
tested.CancelRefresh(peer)
if _, ok := tested.turnCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in turn cancel map, got present")
}
if _, ok := tested.relayCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in relay cancel map, got present")
}
}
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
t.Helper()
mac := hmac.New(algo, key)
_, err := mac.Write([]byte(username))
if err != nil {
t.Fatal(err)
}
expectedMAC := mac.Sum(nil)
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
if err != nil {
t.Fatal(err)
}
equal := hmac.Equal(decodedMAC, expectedMAC)
if !equal {
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
}
}

View File

@@ -1,587 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -1,348 +0,0 @@
//go:build integration
package grpc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type validateSessionTestSetup struct {
proxyService *ProxyServiceServer
store store.Store
cleanup func()
}
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
return &validateSessionTestSetup{
proxyService: proxyService,
store: testStore,
cleanup: storeCleanup,
}
}
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
}
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "restricted-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User in different account should be denied")
assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Non-existent user should be denied")
assert.Equal(t, "user_not_found", resp.DeniedReason)
}
func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "unknown-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "service_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: "invalid-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Invalid token should be denied")
assert.Equal(t, "invalid_token", resp.DeniedReason)
}
func TestValidateSession_MissingDomain(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
SessionToken: "some-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
func TestValidateSession_MissingToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionServiceManager struct {
store store.Store
}
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}

View File

@@ -376,7 +376,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -493,21 +492,6 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
oldEnabled := oldSettings.PeerExposeEnabled
newEnabled := newSettings.PeerExposeEnabled
if oldEnabled == newEnabled {
return
}
event := activity.AccountPeerExposeEnabled
if !newEnabled {
event = activity.AccountPeerExposeDisabled
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -730,11 +714,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.Id == userID {
continue

View File

@@ -1,7 +1,5 @@
package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"net"
@@ -63,11 +61,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
@@ -28,13 +27,10 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
@@ -1806,12 +1802,12 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
Services: []*service.Service{
Services: []*reverseproxy.Service{
{
ID: "service1",
Name: "test-service",
AccountID: "account1",
Targets: []*service.Target{},
Targets: []*reverseproxy.Target{},
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
@@ -3116,12 +3112,6 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
CleanupStale(gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -3132,12 +3122,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
return manager, updateManager, nil
}

View File

@@ -208,18 +208,6 @@ const (
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
PeerServiceExposed Activity = 111
// PeerServiceUnexposed indicates that a peer-exposed service was removed
PeerServiceUnexposed Activity = 112
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
PeerServiceExposeExpired Activity = 113
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
AccountPeerExposeEnabled Activity = 114
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
AccountDeleted Activity = 99999
)
@@ -357,13 +345,6 @@ var activityMap = map[Activity]Code{
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
}
// StringCode returns a string code of the activity

View File

@@ -249,15 +249,7 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
switch storeEngine {
case types.SqliteStoreEngine:
dbFile := eventSinkDB
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
dbFile = envFile
}
connStr := dbFile
if !filepath.IsAbs(dbFile) {
connStr = filepath.Join(dataDir, dbFile)
}
dialector = sqlite.Open(connStr)
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {

View File

@@ -425,11 +425,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string
var deletedGroups []*types.Group
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
@@ -438,7 +433,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
continue
}
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
@@ -626,7 +621,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil
}
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -646,10 +641,6 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}

View File

@@ -12,7 +12,6 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -27,7 +26,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
@@ -286,67 +284,6 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctrl := gomock.NewController(t)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
AnyTimes()
settingsMock.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
am.settingsManager = settingsMock
_, account, err := initTestGroupAccount(am)
require.NoError(t, err)
grp := &types.Group{
ID: "grp-for-flow",
AccountID: account.Id,
Name: "Group for flow",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
require.Error(t, err)
var gErr *GroupLinkError
require.ErrorAs(t, err, &gErr)
assert.Equal(t, "settings", gErr.Resource)
assert.Equal(t, "traffic event logging", gErr.Name)
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
regularGrp := &types.Group{
ID: "grp-regular",
AccountID: account.Id,
Name: "Regular group",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
require.NoError(t, err)
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
require.Error(t, err)
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
assert.Error(t, err)
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc"
domain := "example.com"

View File

@@ -168,10 +168,6 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
}
returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
@@ -179,9 +175,6 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
PeerExposeGroups: req.Settings.PeerExposeGroups,
}
if req.Settings.Extra != nil {
@@ -343,8 +336,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: settings.PeerExposeEnabled,
PeerExposeGroups: settings.PeerExposeGroups,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,

View File

@@ -209,10 +209,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcConfig,
nil,
usersManager,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})
proxyService.SetProxyManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)
@@ -241,12 +240,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &service.Service{
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Targets: []*service.Target{{
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -256,8 +255,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -267,12 +266,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &service.Service{
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Targets: []*service.Target{{
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -282,8 +281,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"restrictedGroupId"},
},
@@ -293,12 +292,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &service.Service{
noAuthProxy := &reverseproxy.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com",
Targets: []*service.Target{{
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -308,8 +307,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: false,
},
},
@@ -359,23 +358,19 @@ type testServiceManager struct {
store store.Store
}
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
@@ -387,7 +382,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
return nil
}
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
@@ -399,15 +394,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
return nil
}
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
@@ -415,20 +410,6 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil
}
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()

View File

@@ -9,8 +9,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
@@ -98,19 +96,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyMgr := proxymanager.NewManager(store)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked

View File

@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
type EmbeddedStorageConfig struct {
// Type is the storage type: "sqlite3" (default) or "postgres"
// Type is the storage type (currently only "sqlite3" is supported)
Type string
// Config contains type-specific configuration
Config EmbeddedStorageTypeConfig
@@ -62,8 +62,6 @@ type EmbeddedStorageConfig struct {
type EmbeddedStorageTypeConfig struct {
// File is the path to the SQLite database file (for sqlite3 type)
File string
// DSN is the connection string for postgres
DSN string
}
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
@@ -76,22 +74,6 @@ type OwnerConfig struct {
Username string
}
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
switch storageType {
case "sqlite3":
return map[string]interface{}{
"file": cfg.File,
}, nil
case "postgres":
return map[string]interface{}{
"dsn": cfg.DSN,
}, nil
default:
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
}
}
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Issuer == "" {
@@ -103,14 +85,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
return nil, fmt.Errorf("storage file is required for sqlite3")
}
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
return nil, fmt.Errorf("storage DSN is required for postgres")
}
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
if err != nil {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs
@@ -126,8 +100,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
Storage: dex.Storage{
Type: c.Storage.Type,
Config: storageConfig,
Type: c.Storage.Type,
Config: map[string]interface{}{
"file": c.Storage.Config.File,
},
},
Web: dex.Web{
AllowedOrigins: []string{"*"},

View File

@@ -14,7 +14,6 @@ import (
"github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/types"
@@ -52,7 +51,6 @@ type properties map[string]interface{}
type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
}
// ConnManager peer connection manager that holds state for current active connections
@@ -212,17 +210,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rosenpassEnabled int
localUsers int
idpUsers int
embeddedIdpTypes map[string]int
services int
servicesEnabled int
servicesTargets int
servicesStatusActive int
servicesStatusPending int
servicesStatusError int
servicesTargetType map[string]int
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
)
start := time.Now()
metricsProperties := make(properties)
@@ -231,14 +218,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
embeddedIdpTypes = make(map[string]int)
servicesTargetType = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx)
for _, account := range w.dataSource.GetAllAccounts(ctx) {
accounts++
@@ -295,8 +278,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} else {
idpUsers++
}
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
}
}
@@ -350,37 +331,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
}
}
for _, service := range account.Services {
services++
if service.Enabled {
servicesEnabled++
}
servicesTargets += len(service.Targets)
switch rpservice.Status(service.Meta.Status) {
case rpservice.StatusActive:
servicesStatusActive++
case rpservice.StatusPending:
servicesStatusPending++
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
servicesStatusError++
}
for _, target := range service.Targets {
servicesTargetType[target.TargetType]++
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled {
servicesAuthPassword++
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled {
servicesAuthPin++
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++
}
}
}
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
@@ -419,27 +369,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
metricsProperties["local_users_count"] = localUsers
metricsProperties["idp_users_count"] = idpUsers
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
metricsProperties["services"] = services
metricsProperties["services_enabled"] = servicesEnabled
metricsProperties["services_targets"] = servicesTargets
metricsProperties["services_status_active"] = servicesStatusActive
metricsProperties["services_status_pending"] = servicesStatusPending
metricsProperties["services_status_error"] = servicesStatusError
metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated
for targetType, count := range servicesTargetType {
metricsProperties["services_target_type_"+targetType] = count
}
for idpType, count := range embeddedIdpTypes {
metricsProperties["embedded_idp_users_"+idpType] = count
}
for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count
@@ -527,20 +456,6 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
return req, cancel, nil
}
// extractIdpType extracts the IdP type from a Dex connector ID.
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string {
if connectorID == "local" {
return "local"
}
idx := strings.LastIndex(connectorID, "-")
if idx <= 0 {
return "oidc"
}
return strings.ToLower(connectorID[:idx])
}
func getMinMaxVersion(inputList []string) (string, string) {
versions := make([]*version.Version, 0)

View File

@@ -6,7 +6,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -28,8 +27,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
idpUserID := dex.EncodeDexUserID("20", "zitadel")
return []*types.Account{
{
Id: "1",
@@ -117,31 +115,6 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
},
},
},
Services: []*rpservice.Service{
{
ID: "svc1",
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: "peer"},
{TargetType: "host"},
},
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
},
{
ID: "svc2",
Enabled: false,
Targets: []*rpservice.Target{
{TargetType: "domain"},
},
Auth: rpservice.AuthConfig{
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
},
},
},
{
Id: "2",
@@ -207,13 +180,6 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {},
},
},
oidcUserID: {
Id: oidcUserID,
IsServiceUser: false,
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
},
Networks: []*networkTypes.Network{
{
@@ -249,11 +215,6 @@ func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// GetCustomDomainsCounts returns test custom domain counts.
func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 3, 2, nil
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{}
@@ -286,14 +247,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"])
}
if properties["users"] != 3 {
t.Errorf("expected 3 users, got %d", properties["users"])
if properties["users"] != 2 {
t.Errorf("expected 1 users, got %d", properties["users"])
}
if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
}
if properties["pats"] != 5 {
t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
if properties["pats"] != 4 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
}
if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
@@ -377,90 +338,7 @@ func TestGenerateProperties(t *testing.T) {
if properties["local_users_count"] != 1 {
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
}
if properties["idp_users_count"] != 2 {
t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_local"] != 1 {
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
}
if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
}
if properties["embedded_idp_users_oidc"] != 1 {
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
}
if properties["embedded_idp_count"] != 3 {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
if properties["services"] != 2 {
t.Errorf("expected 2 services, got %v", properties["services"])
}
if properties["services_enabled"] != 1 {
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
}
if properties["services_targets"] != 3 {
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
}
if properties["services_status_active"] != 1 {
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
}
if properties["services_status_pending"] != 1 {
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
}
if properties["services_status_error"] != 0 {
t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"])
}
if properties["services_target_type_peer"] != 1 {
t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"])
}
if properties["services_target_type_host"] != 1 {
t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"])
}
if properties["services_target_type_domain"] != 1 {
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
}
if properties["services_auth_password"] != 1 {
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
}
if properties["services_auth_oidc"] != 1 {
t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"])
}
if properties["services_auth_pin"] != 0 {
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
}
if properties["custom_domains"] != int64(3) {
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
}
if properties["custom_domains_validated"] != int64(2) {
t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"])
}
}
func TestExtractIdpType(t *testing.T) {
tests := []struct {
connectorID string
expected string
}{
{"okta-abc123def", "okta"},
{"zitadel-d5uv82dra0haedlf6kv0", "zitadel"},
{"entra-xyz789", "entra"},
{"google-abc123", "google"},
{"pocketid-abc123", "pocketid"},
{"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "local"},
{"d6jvvp69kmnc73c9pl40", "oidc"},
{"", "oidc"},
}
for _, tt := range tests {
t.Run(tt.connectorID, func(t *testing.T) {
result := extractIdpType(tt.connectorID)
if result != tt.expected {
t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected)
}
})
if properties["idp_users_count"] != 1 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
}
}

View File

@@ -407,7 +407,7 @@ func (am *MockAccountManager) AddPeer(
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil {
if am.GetGroupFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"

View File

@@ -352,10 +352,9 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest)
p.Name = a.Name
p.Key = a.WgPubKey
p.Meta = PeerSystemMeta{
Hostname: a.Name,
GoOS: "js",
OS: "js",
KernelVersion: "wasm",
Hostname: a.Name,
GoOS: "js",
OS: "js",
}
}

View File

@@ -269,8 +269,3 @@ func (s *FileStore) GetStoreEngine() types.Engine {
func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
// no-op: FileStore stores data in plaintext JSON; encryption is not supported
}
// GetCustomDomainsCounts is a no-op for FileStore as it doesn't support custom domains.
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 0, 0, nil
}

View File

@@ -1008,18 +1008,6 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
return count, nil
}
// GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err
}
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err
}
return total, validated, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account
result := s.db.Find(&accounts)
@@ -2122,13 +2110,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
}
}
s.Meta = rpservice.Meta{}
s.Meta = rpservice.ServiceMeta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
}
if certIssuedAt.Valid {
t := certIssuedAt.Time
s.Meta.CertificateIssuedAt = &t
s.Meta.CertificateIssuedAt = certIssuedAt.Time
}
if status.Valid {
s.Meta.Status = status.String
@@ -2729,28 +2716,14 @@ func (s *SqlStore) GetStoreEngine() types.Engine {
// NewSqliteStore creates a new SQLite store.
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
storeFile := storeSqliteFileName
if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
storeFile = envFile
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = storeSqliteFileName
}
// Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
filePath, query, hasQuery := strings.Cut(storeFile, "?")
connStr := filePath
if !filepath.IsAbs(filePath) {
connStr = filepath.Join(dataDir, filePath)
}
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
if hasQuery {
connStr += "?" + query
} else if runtime.GOOS != "windows" {
// To avoid `The process cannot access the file because it is being used by another process` on Windows
connStr += "?cache=shared"
}
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, err
}
@@ -4913,46 +4886,6 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil
}
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete target from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
}
return nil
}
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete targets from store")
}
return nil
}
// GetTargetsByServiceID retrieves all targets for a given service
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
var targets []*rpservice.Target
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get targets from store")
}
return targets, nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
@@ -5040,99 +4973,6 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
return serviceList, nil
}
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Update("meta_last_renewed_at", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
return status.Errorf(status.Internal, "renew ephemeral service")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
// skip malformed legacy data.
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
cutoff := time.Now().Add(-ttl)
var services []*rpservice.Service
result := s.db.
Select("id", "account_id", "source_peer", "domain").
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
Limit(limit).
Find(&services)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
}
return services, nil
}
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return count, nil
}
var ids []string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Pluck("id", &ids)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return int64(len(ids)), nil
}
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return count > 0, nil
}
var id string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Limit(1).
Pluck("id", &id)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return id != "", nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
tx := s.db

View File

@@ -261,11 +261,6 @@ type Store interface {
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
@@ -277,16 +272,11 @@ type Store interface {
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
}
const (

View File

@@ -208,21 +208,6 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
}
// CountEphemeralServicesByPeer mocks base method.
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
}
// CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper()
@@ -589,20 +574,6 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
}
// DeleteServiceTargets mocks base method.
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
}
// DeleteSetupKey mocks base method.
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
m.ctrl.T.Helper()
@@ -617,20 +588,6 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
}
// DeleteTarget mocks base method.
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTarget indicates an expected call of DeleteTarget.
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
}
// DeleteTokenID2UserIDIndex mocks base method.
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
m.ctrl.T.Helper()
@@ -701,21 +658,6 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// EphemeralServiceExists mocks base method.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
}
// ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper()
@@ -1361,22 +1303,6 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetDNSRecordByID mocks base method.
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
m.ctrl.T.Helper()
@@ -1392,21 +1318,6 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
}
// GetGroupByID mocks base method.
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
m.ctrl.T.Helper()
@@ -2050,21 +1961,6 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
}
// GetTargetsByServiceID mocks base method.
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].([]*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
}
// GetTokenIDByHashedToken mocks base method.
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
m.ctrl.T.Helper()
@@ -2446,20 +2342,6 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
}
// RenewEphemeralService mocks base method.
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
}
// RevokeProxyAccessToken mocks base method.
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
m.ctrl.T.Helper()

View File

@@ -1,576 +0,0 @@
package types
import (
"context"
"slices"
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
func (a *Account) GetPeerNetworkMapFromComponents(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
components := a.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
accountZones,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
return &NetworkMap{Network: a.Network.Copy()}
}
nm := CalculateNetworkMapFromComponents(ctx, components)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from components, "+
"peers: %d, offline peers: %d, routes: %d, firewall rules: %d, route firewall rules: %d",
a.Id, objectCount, len(nm.Peers), len(nm.OfflinePeers), len(nm.Routes), len(nm.FirewallRules), len(nm.RoutesFirewallRules))
}
}
return nm
}
func (a *Account) GetPeerNetworkMapComponents(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
groupIDToUserIDs map[string][]string,
) *NetworkMapComponents {
peer := a.Peers[peerID]
if peer == nil {
return nil
}
if _, ok := validatedPeersMap[peerID]; !ok {
return nil
}
components := &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain,
ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer),
}
components.AccountSettings = &AccountSettingsInfo{
PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: a.Settings.PeerLoginExpiration,
PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: a.Settings.PeerInactivityExpiration,
}
components.DNSSettings = &a.DNSSettings
relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers)
if len(sshReqs.neededGroupIDs) > 0 {
components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs)
}
if sshReqs.needAllowedUserIDs {
components.AllowedUserIDs = a.getAllowedUserIDs()
}
components.Peers = relevantPeers
components.Groups = relevantGroups
components.Policies = relevantPolicies
components.Routes = relevantRoutes
components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers)
peerGroups := a.GetPeerGroups(peerID)
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
for _, nsGroup := range a.NameServerGroups {
if nsGroup.Enabled {
for _, gID := range nsGroup.Groups {
if _, found := relevantGroups[gID]; found {
components.NameServerGroups = append(components.NameServerGroups, nsGroup)
break
}
}
}
}
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
policies, exists := resourcePolicies[resource.ID]
if !exists {
continue
}
addSourcePeers := false
networkRoutingPeers, routerExists := routers[resource.NetworkID]
if routerExists {
if _, ok := networkRoutingPeers[peerID]; ok {
addSourcePeers = true
}
}
for _, policy := range policies {
if addSourcePeers {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, validatedPeersMap, &components.PostureFailedPeers) {
if _, exists := components.Peers[pID]; !exists {
components.Peers[pID] = a.GetPeer(pID)
}
}
} else {
peerInSources := false
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peerInSources = policy.Rules[0].SourceResource.ID == peerID
} else {
for _, groupID := range policy.SourceGroups() {
if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) {
peerInSources = true
break
}
}
}
if !peerInSources {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID)
if !isValid && len(pname) > 0 {
if _, ok := components.PostureFailedPeers[pname]; !ok {
components.PostureFailedPeers[pname] = make(map[string]struct{})
}
components.PostureFailedPeers[pname][peer.ID] = struct{}{}
continue
}
addSourcePeers = true
}
for _, rule := range policy.Rules {
for _, srcGroupID := range rule.Sources {
if g := a.Groups[srcGroupID]; g != nil {
if _, exists := components.Groups[srcGroupID]; !exists {
components.Groups[srcGroupID] = g
}
}
}
for _, dstGroupID := range rule.Destinations {
if g := a.Groups[dstGroupID]; g != nil {
if _, exists := components.Groups[dstGroupID]; !exists {
components.Groups[dstGroupID] = g
}
}
}
}
components.ResourcePoliciesMap[resource.ID] = policies
}
components.RoutersMap[resource.NetworkID] = networkRoutingPeers
for peerIDKey := range networkRoutingPeers {
if p := a.Peers[peerIDKey]; p != nil {
if _, exists := components.RouterPeers[peerIDKey]; !exists {
components.RouterPeers[peerIDKey] = p
}
if _, exists := components.Peers[peerIDKey]; !exists {
if _, validated := validatedPeersMap[peerIDKey]; validated {
components.Peers[peerIDKey] = p
}
}
}
}
if addSourcePeers {
components.NetworkResources = append(components.NetworkResources, resource)
}
}
filterGroupPeers(&components.Groups, components.Peers)
filterPostureFailedPeers(&components.PostureFailedPeers, components.Policies, components.ResourcePoliciesMap, components.Peers)
return components
}
type sshRequirements struct {
neededGroupIDs map[string]struct{}
needAllowedUserIDs bool
}
func (a *Account) getPeersGroupsPoliciesRoutes(
ctx context.Context,
peerID string,
peerSSHEnabled bool,
validatedPeersMap map[string]struct{},
postureFailedPeers *map[string]map[string]struct{},
) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) {
relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4)
relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4)
relevantPolicies := make([]*Policy, 0, len(a.Policies))
relevantRoutes := make([]*route.Route, 0, len(a.Routes))
sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})}
relevantPeerIDs[peerID] = a.GetPeer(peerID)
for groupID, group := range a.Groups {
if slices.Contains(group.Peers, peerID) {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
}
routeAccessControlGroups := make(map[string]struct{})
for _, r := range a.Routes {
for _, groupID := range r.Groups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
for _, groupID := range r.PeerGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
if r.Enabled {
for _, groupID := range r.AccessControlGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
routeAccessControlGroups[groupID] = struct{}{}
}
}
relevantRoutes = append(relevantRoutes, r)
}
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
policyRelevant := false
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if len(routeAccessControlGroups) > 0 {
for _, destGroupID := range rule.Destinations {
if _, needed := routeAccessControlGroups[destGroupID]; needed {
policyRelevant = true
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
}
break
}
}
}
var sourcePeers, destinationPeers []string
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers = []string{rule.SourceResource.ID}
if rule.SourceResource.ID == peerID {
peerInSources = true
}
} else {
sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers = []string{rule.DestinationResource.ID}
if rule.DestinationResource.ID == peerID {
peerInDestinations = true
}
} else {
destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers)
}
if peerInSources {
policyRelevant = true
for _, pid := range destinationPeers {
relevantPeerIDs[pid] = a.GetPeer(pid)
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
}
}
if peerInDestinations {
policyRelevant = true
for _, pid := range sourcePeers {
relevantPeerIDs[pid] = a.GetPeer(pid)
}
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
}
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID := range rule.AuthorizedGroups {
sshReqs.neededGroupIDs[groupID] = struct{}{}
}
case rule.AuthorizedUser != "":
default:
sshReqs.needAllowedUserIDs = true
}
} else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
sshReqs.needAllowedUserIDs = true
}
}
}
if policyRelevant {
relevantPolicies = append(relevantPolicies, policy)
}
}
return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs
}
func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string,
validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) {
peerInGroups := false
filteredPeerIDs := make([]string, 0, len(a.Peers))
seenPeerIds := make(map[string]struct{}, len(groups))
for _, gid := range groups {
group := a.GetGroup(gid)
if group == nil {
continue
}
if group.IsGroupAll() || len(groups) == 1 {
filteredPeerIDs = filteredPeerIDs[:0]
peerInGroups = false
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]
if !ok || peer == nil {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid && len(pname) > 0 {
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
}
return filteredPeerIDs, peerInGroups
}
for _, pid := range group.Peers {
if _, seen := seenPeerIds[pid]; seen {
continue
}
seenPeerIds[pid] = struct{}{}
peer, ok := a.Peers[pid]
if !ok || peer == nil {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid && len(pname) > 0 {
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
}
}
return filteredPeerIDs, peerInGroups
}
func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) {
peer, ok := a.Peers[peerID]
if !ok || peer == nil {
return false, ""
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := a.GetPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
for _, check := range postureChecks.GetChecks() {
isValid, _ := check.Check(ctx, *peer)
if !isValid {
return false, postureChecksID
}
}
}
return true, ""
}
func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) []string {
var dest []string
for _, peerID := range inputPeers {
if _, validated := validatedPeersMap[peerID]; !validated {
continue
}
valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID)
if valid {
dest = append(dest, peerID)
continue
}
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peerID] = struct{}{}
}
return dest
}
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
for groupID, groupInfo := range *groups {
filteredPeers := make([]string, 0, len(groupInfo.Peers))
for _, pid := range groupInfo.Peers {
if _, exists := peers[pid]; exists {
filteredPeers = append(filteredPeers, pid)
}
}
if len(filteredPeers) == 0 {
delete(*groups, groupID)
} else if len(filteredPeers) != len(groupInfo.Peers) {
ng := groupInfo.Copy()
ng.Peers = filteredPeers
(*groups)[groupID] = ng
}
}
}
func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{}, policies []*Policy, resourcePoliciesMap map[string][]*Policy, peers map[string]*nbpeer.Peer) {
if len(*postureFailedPeers) == 0 {
return
}
referencedPostureChecks := make(map[string]struct{})
for _, policy := range policies {
for _, checkID := range policy.SourcePostureChecks {
referencedPostureChecks[checkID] = struct{}{}
}
}
for _, resPolicies := range resourcePoliciesMap {
for _, policy := range resPolicies {
for _, checkID := range policy.SourcePostureChecks {
referencedPostureChecks[checkID] = struct{}{}
}
}
}
for checkID, failedPeers := range *postureFailedPeers {
if _, referenced := referencedPostureChecks[checkID]; !referenced {
delete(*postureFailedPeers, checkID)
continue
}
for peerID := range failedPeers {
if _, exists := peers[peerID]; !exists {
delete(failedPeers, peerID)
}
}
if len(failedPeers) == 0 {
delete(*postureFailedPeers, checkID)
}
}
}
func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord {
if len(records) == 0 || len(peers) == 0 {
return nil
}
peerIPs := make(map[string]struct{}, len(peers))
for _, peer := range peers {
if peer != nil {
peerIPs[peer.IP.String()] = struct{}{}
}
}
filteredRecords := make([]nbdns.SimpleRecord, 0, len(records))
for _, record := range records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
}
}
return filteredRecords
}
func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string {
if len(neededGroupIDs) == 0 {
return nil
}
filtered := make(map[string][]string, len(neededGroupIDs))
for groupID := range neededGroupIDs {
if users, ok := fullMap[groupID]; ok {
filtered[groupID] = users
}
}
return filtered
}

View File

@@ -1,592 +0,0 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
t.Fatal("GetPeerNetworkMapComponents returned nil")
}
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
if newNetworkMap == nil {
t.Fatal("CalculateNetworkMapFromComponents returned nil")
}
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
}
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
goldenDir := filepath.Join("testdata", "comparison")
err = os.MkdirAll(goldenDir, 0755)
require.NoError(t, err)
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
require.NoError(t, err, "error writing legacy golden file")
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
err = os.WriteFile(newGoldenPath, newJSON, 0644)
require.NoError(t, err, "error writing components golden file")
componentsPath := filepath.Join(goldenDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err, "error writing components golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s",
legacyGoldenPath, newGoldenPath)
t.Logf("✅ NetworkMaps are identical")
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
func normalizeAndSortNetworkMap(nm *NetworkMap) {
if nm == nil {
return
}
sort.Slice(nm.Peers, func(i, j int) bool {
return nm.Peers[i].ID < nm.Peers[j].ID
})
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
})
sort.Slice(nm.Routes, func(i, j int) bool {
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
})
sort.Slice(nm.FirewallRules, func(i, j int) bool {
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
}
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
}
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
}
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
}
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
})
for i := range nm.RoutesFirewallRules {
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
}
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
}
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
}
for k := 0; k < minLen; k++ {
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
}
}
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
}
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
}
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
}
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
}
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
})
if nm.DNSConfig.CustomZones != nil {
for i := range nm.DNSConfig.CustomZones {
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
})
}
}
if len(nm.DNSConfig.NameServerGroups) != 0 {
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
})
}
}
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
t.Helper()
if legacy.Network.Serial != current.Network.Serial {
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
}
if len(legacy.Peers) != len(current.Peers) {
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
}
legacyPeerIDs := make(map[string]bool)
for _, p := range legacy.Peers {
legacyPeerIDs[p.ID] = true
}
for _, p := range current.Peers {
if !legacyPeerIDs[p.ID] {
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
}
}
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
}
if len(legacy.FirewallRules) != len(current.FirewallRules) {
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
}
if len(legacy.Routes) != len(current.Routes) {
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
}
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
}
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
}
}
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60"
expiredPeerID = "peer-98"
offlinePeerID = "peer-99"
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
func createTestAccount() *Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
}
policies := []*Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
account := &Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Network: &Network{
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func BenchmarkLegacyNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
}
}
func BenchmarkComponentsNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func BenchmarkComponentsCreation(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
}
}
func BenchmarkCalculationFromComponents(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}

View File

@@ -1,938 +0,0 @@
package types
import (
"context"
"maps"
"net"
"net/netip"
"slices"
"strconv"
"strings"
"time"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
type NetworkMapComponents struct {
PeerID string
Network *Network
AccountSettings *AccountSettingsInfo
DNSSettings *DNSSettings
CustomZoneDomain string
Peers map[string]*nbpeer.Peer
Groups map[string]*Group
Policies []*Policy
Routes []*route.Route
NameServerGroups []*nbdns.NameServerGroup
AllDNSRecords []nbdns.SimpleRecord
AccountZones []nbdns.CustomZone
ResourcePoliciesMap map[string][]*Policy
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
NetworkResources []*resourceTypes.NetworkResource
GroupIDToUserIDs map[string][]string
AllowedUserIDs map[string]struct{}
PostureFailedPeers map[string]map[string]struct{}
RouterPeers map[string]*nbpeer.Peer
}
type AccountSettingsInfo struct {
PeerLoginExpirationEnabled bool
PeerLoginExpiration time.Duration
PeerInactivityExpirationEnabled bool
PeerInactivityExpiration time.Duration
}
func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer {
return c.Peers[peerID]
}
func (c *NetworkMapComponents) GetRouterPeerInfo(peerID string) *nbpeer.Peer {
return c.RouterPeers[peerID]
}
func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group {
return c.Groups[groupID]
}
func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool {
group := c.GetGroupInfo(groupID)
if group == nil {
return false
}
return slices.Contains(group.Peers, peerID)
}
func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} {
groups := make(map[string]struct{})
for groupID, group := range c.Groups {
if slices.Contains(group.Peers, peerID) {
groups[groupID] = struct{}{}
}
}
return groups
}
func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool {
_, exists := c.Peers[peerID]
if !exists {
return false
}
if len(postureCheckIDs) == 0 {
return true
}
for _, checkID := range postureCheckIDs {
if failedPeers, exists := c.PostureFailedPeers[checkID]; exists {
if _, failed := failedPeers[peerID]; failed {
return false
}
}
}
return true
}
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
return components.Calculate(ctx)
}
func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
targetPeerID := c.PeerID
peerGroups := c.GetPeerGroups(targetPeerID)
aclPeers, firewallRules, authorizedUsers, sshEnabled := c.getPeerConnectionResources(targetPeerID)
peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers)
routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups)
routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID)
isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes)
}
peersToConnectIncludingRouters := c.addNetworksRoutingPeers(
networkResourcesRoutes,
targetPeerID,
peersToConnect,
expiredPeers,
isRouter,
sourcePeers,
)
dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var customZones []nbdns.CustomZone
if c.CustomZoneDomain != "" && len(c.AllDNSRecords) > 0 {
customZones = append(customZones, nbdns.CustomZone{
Domain: c.CustomZoneDomain,
Records: c.AllDNSRecords,
})
}
customZones = append(customZones, c.AccountZones...)
dnsUpdate.CustomZones = customZones
dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID)
}
return &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: c.Network.Copy(),
Routes: append(networkResourcesRoutes, routesUpdate...),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
AuthorizedUsers: authorizedUsers,
EnableSSH: sshEnabled,
}
}
func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
targetPeer := c.GetPeerInfo(targetPeerID)
if targetPeer == nil {
return nil, nil, nil, false
}
generateResources, getAccumulatedResources := c.connResourcesGenerator(targetPeer)
authorizedUsers := make(map[string]map[string]struct{})
sshEnabled := false
for _, policy := range c.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
var sourcePeers, destinationPeers []*nbpeer.Peer
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers, peerInSources = c.getPeerFromResource(rule.SourceResource, targetPeerID)
} else {
sourcePeers, peerInSources = c.getAllPeersFromGroups(rule.Sources, targetPeerID, policy.SourcePostureChecks)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers, peerInDestinations = c.getPeerFromResource(rule.DestinationResource, targetPeerID)
} else {
destinationPeers, peerInDestinations = c.getAllPeersFromGroups(rule.Destinations, targetPeerID, nil)
}
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := c.GroupIDToUserIDs[groupID]
if !ok {
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
}
}
peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
func (c *NetworkMapComponents) getAllowedUserIDs() map[string]struct{} {
if c.AllowedUserIDs != nil {
result := make(map[string]struct{}, len(c.AllowedUserIDs))
maps.Copy(result, c.AllowedUserIDs)
return result
}
return make(map[string]struct{})
}
func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: net.IP(peer.IP).String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr)
continue
}
rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{
ID: rule.ID,
Ports: rule.Ports,
PortRanges: rule.PortRanges,
Protocol: rule.Protocol,
Action: rule.Action,
}, targetPeer)...)
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
}
}
func (c *NetworkMapComponents) getAllPeersFromGroups(groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) {
peerInGroups := false
uniquePeerIDs := c.getUniquePeerIDsFromGroupsIDs(groups)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peerInfo := c.GetPeerInfo(p)
if peerInfo == nil {
continue
}
if _, ok := c.Peers[p]; !ok {
continue
}
if !c.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) {
continue
}
if p == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peerInfo)
}
return filteredPeers, peerInGroups
}
func (c *NetworkMapComponents) getUniquePeerIDsFromGroupsIDs(groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups))
for _, groupID := range groups {
group := c.GetGroupInfo(groupID)
if group == nil {
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
if resource.ID == peerID {
return []*nbpeer.Peer{}, true
}
peerInfo := c.GetPeerInfo(resource.ID)
if peerInfo == nil {
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peerInfo}, false
}
func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) {
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(c.AccountSettings.PeerLoginExpiration)
if c.AccountSettings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
return peersToConnect, expiredPeers
}
func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := c.GetPeerGroups(peerID)
enabled := true
for _, groupID := range c.DNSSettings.DisabledManagementGroups {
if _, found := peerGroups[groupID]; found {
enabled = false
break
}
}
return enabled
}
func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup {
groupList := c.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range c.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
targetPeerInfo := c.GetPeerInfo(peerID)
if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peerInfo.IP.String() == ns.IP.String() {
return true
}
}
return false
}
func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
for _, peer := range aclPeers {
activeRoutes, _ := c.getRoutingPeerRoutes(peer.ID)
groupFilteredRoutes := c.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := c.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peerInfo := c.GetPeerInfo(peerID)
if peerInfo == nil {
peerInfo = c.GetRouterPeerInfo(peerID)
}
if peerInfo == nil {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
routeObj := c.copyRoute(r)
routeObj.Peer = peerInfo.Key
if r.Enabled {
enabledRoutes = append(enabledRoutes, routeObj)
return
}
disabledRoutes = append(disabledRoutes, routeObj)
}
for _, r := range c.Routes {
for _, groupID := range r.PeerGroups {
group := c.GetGroupInfo(groupID)
if group == nil {
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := c.copyRoute(r)
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id)
takeRoute(newPeerRoute)
break
}
}
if r.Peer == peerID {
takeRoute(c.copyRoute(r))
}
}
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route {
var groups, accessControlGroups, peerGroups []string
var domains domain.List
if r.Groups != nil {
groups = append([]string{}, r.Groups...)
}
if r.AccessControlGroups != nil {
accessControlGroups = append([]string{}, r.AccessControlGroups...)
}
if r.PeerGroups != nil {
peerGroups = append([]string{}, r.PeerGroups...)
}
if r.Domains != nil {
domains = append(domain.List{}, r.Domains...)
}
return &route.Route{
ID: r.ID,
AccountID: r.AccountID,
Network: r.Network,
NetworkType: r.NetworkType,
Description: r.Description,
Peer: r.Peer,
PeerID: r.PeerID,
Metric: r.Metric,
Masquerade: r.Masquerade,
NetID: r.NetID,
Enabled: r.Enabled,
Groups: groups,
AccessControlGroups: accessControlGroups,
PeerGroups: peerGroups,
Domains: domains,
KeepRoute: r.KeepRoute,
SkipAutoApply: r.SkipAutoApply,
}
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
enabledRoutes, _ := c.getRoutingPeerRoutes(peerID)
for _, r := range enabledRoutes {
if len(r.AccessControlGroups) == 0 {
defaultPermit := c.getDefaultPermit(r)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := c.getDistributionGroupsPeers(r)
for _, accessGroup := range r.AccessControlGroups {
policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup})
rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if r.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: r.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: r.Domains,
IsDynamic: r.IsDynamic(),
RouteID: r.ID,
}
rules = append(rules, &rule)
if r.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
func (c *NetworkMapComponents) getDistributionGroupsPeers(r *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range r.Groups {
group := c.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy {
routePolicies := make([]*Policy, 0)
for _, groupID := range accessControlGroups {
for _, policy := range c.Policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Destinations, groupID) {
routePolicies = append(routePolicies, policy)
}
}
}
}
return routePolicies
}
func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
}
return fwRules
}
func (c *NetworkMapComponents) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := c.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
if pID == peerID {
continue
}
_, distPeer := distributionPeers[pID]
_, valid := c.Peers[pID]
if distPeer && valid && c.ValidatePostureChecksOnPeer(pID, postureChecks) {
distPeersWithPolicy[pID] = struct{}{}
}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
_, distPeer := distributionPeers[rule.SourceResource.ID]
_, valid := c.Peers[rule.SourceResource.ID]
if distPeer && valid && c.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) {
distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
}
}
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
peerInfo := c.GetPeerInfo(pID)
if peerInfo == nil {
continue
}
distributionGroupPeers = append(distributionGroupPeers, peerInfo)
}
return distributionGroupPeers
}
func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{})
for _, resource := range c.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := c.RoutersMap[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerID, router)...)
}
}
addedResourceRoute := false
for _, policy := range c.ResourcePoliciesMap[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) {
for peerId, router := range networkRoutingPeers {
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route {
resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID]
var routes []*route.Route
if len(resourceAppliedPolicies) > 0 {
peerInfo := c.GetPeerInfo(peerID)
if peerInfo != nil {
routes = append(routes, c.networkResourceToRoute(resource, peerInfo, router))
}
}
return routes
}
func (c *NetworkMapComponents) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route {
r := &route.Route{
ID: route.ID(resource.ID + ":" + peer.ID),
AccountID: resource.AccountID,
Peer: peer.Key,
PeerID: peer.ID,
Metric: router.Metric,
Masquerade: router.Masquerade,
Enabled: resource.Enabled,
KeepRoute: true,
NetID: route.NetID(resource.Name),
Description: resource.Description,
}
if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet {
r.Network = resource.Prefix
r.NetworkType = route.IPv4Network
if resource.Prefix.Addr().Is6() {
r.NetworkType = route.IPv6Network
}
}
if resource.Type == resourceTypes.Domain {
domainList, err := domain.FromStringList([]string{resource.Domain})
if err == nil {
r.Domains = domainList
r.NetworkType = route.DomainNetwork
r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
}
return r
}
func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if c.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) {
dest = append(dest, peerID)
}
}
return dest
}
func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
peerInfo := c.GetPeerInfo(peerID)
if peerInfo == nil {
return routesFirewallRules
}
for _, r := range routes {
if r.Peer != peerInfo.Key {
continue
}
resourceID := string(r.GetResourceID())
resourcePolicies := c.ResourcePoliciesMap[resourceID]
distributionPeers := c.getPoliciesSourcePeers(resourcePolicies)
rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers)
for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
}
return routesFirewallRules
}
func (c *NetworkMapComponents) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} {
sourcePeers := make(map[string]struct{})
for _, policy := range policies {
for _, rule := range policy.Rules {
for _, sourceGroup := range rule.Sources {
group := c.GetGroupInfo(sourceGroup)
if group == nil {
continue
}
for _, peer := range group.Peers {
sourcePeers[peer] = struct{}{}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers[rule.SourceResource.ID] = struct{}{}
}
}
}
return sourcePeers
}
func (c *NetworkMapComponents) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peerID string,
peersToConnect []*nbpeer.Peer,
expiredPeers []*nbpeer.Peer,
isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, r := range networkResourcesRoutes {
networkRoutesPeers[r.PeerID] = struct{}{}
}
delete(sourcePeers, peerID)
delete(networkRoutesPeers, peerID)
for _, existingPeer := range peersToConnect {
delete(sourcePeers, existingPeer.ID)
delete(networkRoutesPeers, existingPeer.ID)
}
for _, expPeer := range expiredPeers {
delete(sourcePeers, expPeer.ID)
delete(networkRoutesPeers, expPeer.ID)
}
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
if isRouter {
for p := range sourcePeers {
missingPeers[p] = struct{}{}
}
}
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers {
peerInfo := c.GetPeerInfo(p)
if peerInfo == nil {
peerInfo = c.GetRouterPeerInfo(p)
}
if peerInfo != nil {
peersToConnect = append(peersToConnect, peerInfo)
}
}
return peersToConnect
}

View File

@@ -1,230 +0,0 @@
package types
import (
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
)
type GroupCompact struct {
Name string
PeerIndexes []int
}
type NetworkMapComponentsCompact struct {
PeerID string
Network *Network
AccountSettings *AccountSettingsInfo
DNSSettings *DNSSettings
CustomZoneDomain string
AllPeers []*nbpeer.Peer
PeerIndexes []int
RouterPeerIndexes []int
Groups map[string]*GroupCompact
AllPolicies []*Policy
PolicyIndexes []int
ResourcePoliciesMap map[string][]int
Routes []*route.Route
NameServerGroups []*nbdns.NameServerGroup
AllDNSRecords []nbdns.SimpleRecord
AccountZones []nbdns.CustomZone
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
NetworkResources []*resourceTypes.NetworkResource
GroupIDToUserIDs map[string][]string
AllowedUserIDs map[string]struct{}
PostureFailedPeers map[string]map[string]struct{}
}
func (c *NetworkMapComponents) ToCompact() *NetworkMapComponentsCompact {
peerToIndex := make(map[string]int)
var allPeers []*nbpeer.Peer
for id, peer := range c.Peers {
if _, exists := peerToIndex[id]; !exists {
peerToIndex[id] = len(allPeers)
allPeers = append(allPeers, peer)
}
}
for id, peer := range c.RouterPeers {
if _, exists := peerToIndex[id]; !exists {
peerToIndex[id] = len(allPeers)
allPeers = append(allPeers, peer)
}
}
peerIndexes := make([]int, 0, len(c.Peers))
for id := range c.Peers {
peerIndexes = append(peerIndexes, peerToIndex[id])
}
routerPeerIndexes := make([]int, 0, len(c.RouterPeers))
for id := range c.RouterPeers {
routerPeerIndexes = append(routerPeerIndexes, peerToIndex[id])
}
groups := make(map[string]*GroupCompact, len(c.Groups))
for id, group := range c.Groups {
peerIdxs := make([]int, 0, len(group.Peers))
for _, peerID := range group.Peers {
if idx, ok := peerToIndex[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
groups[id] = &GroupCompact{
Name: group.Name,
PeerIndexes: peerIdxs,
}
}
policyToIndex := make(map[*Policy]int)
var allPolicies []*Policy
for _, policy := range c.Policies {
if _, exists := policyToIndex[policy]; !exists {
policyToIndex[policy] = len(allPolicies)
allPolicies = append(allPolicies, policy)
}
}
for _, policies := range c.ResourcePoliciesMap {
for _, policy := range policies {
if _, exists := policyToIndex[policy]; !exists {
policyToIndex[policy] = len(allPolicies)
allPolicies = append(allPolicies, policy)
}
}
}
policyIndexes := make([]int, len(c.Policies))
for i, policy := range c.Policies {
policyIndexes[i] = policyToIndex[policy]
}
var resourcePoliciesMap map[string][]int
if len(c.ResourcePoliciesMap) > 0 {
resourcePoliciesMap = make(map[string][]int, len(c.ResourcePoliciesMap))
for resID, policies := range c.ResourcePoliciesMap {
indexes := make([]int, len(policies))
for i, policy := range policies {
indexes[i] = policyToIndex[policy]
}
resourcePoliciesMap[resID] = indexes
}
}
return &NetworkMapComponentsCompact{
PeerID: c.PeerID,
Network: c.Network,
AccountSettings: c.AccountSettings,
DNSSettings: c.DNSSettings,
CustomZoneDomain: c.CustomZoneDomain,
AllPeers: allPeers,
PeerIndexes: peerIndexes,
RouterPeerIndexes: routerPeerIndexes,
Groups: groups,
AllPolicies: allPolicies,
PolicyIndexes: policyIndexes,
ResourcePoliciesMap: resourcePoliciesMap,
Routes: c.Routes,
NameServerGroups: c.NameServerGroups,
AllDNSRecords: c.AllDNSRecords,
AccountZones: c.AccountZones,
RoutersMap: c.RoutersMap,
NetworkResources: c.NetworkResources,
GroupIDToUserIDs: c.GroupIDToUserIDs,
AllowedUserIDs: c.AllowedUserIDs,
PostureFailedPeers: c.PostureFailedPeers,
}
}
func (c *NetworkMapComponentsCompact) ToFull() *NetworkMapComponents {
peers := make(map[string]*nbpeer.Peer, len(c.PeerIndexes))
for _, idx := range c.PeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peer := c.AllPeers[idx]
peers[peer.ID] = peer
}
}
routerPeers := make(map[string]*nbpeer.Peer, len(c.RouterPeerIndexes))
for _, idx := range c.RouterPeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peer := c.AllPeers[idx]
routerPeers[peer.ID] = peer
}
}
groups := make(map[string]*Group, len(c.Groups))
for id, gc := range c.Groups {
peerIDs := make([]string, 0, len(gc.PeerIndexes))
for _, idx := range gc.PeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peerIDs = append(peerIDs, c.AllPeers[idx].ID)
}
}
groups[id] = &Group{
ID: id,
Name: gc.Name,
Peers: peerIDs,
}
}
policies := make([]*Policy, len(c.PolicyIndexes))
for i, idx := range c.PolicyIndexes {
if idx >= 0 && idx < len(c.AllPolicies) {
policies[i] = c.AllPolicies[idx]
}
}
var resourcePoliciesMap map[string][]*Policy
if len(c.ResourcePoliciesMap) > 0 {
resourcePoliciesMap = make(map[string][]*Policy, len(c.ResourcePoliciesMap))
for resID, indexes := range c.ResourcePoliciesMap {
pols := make([]*Policy, 0, len(indexes))
for _, idx := range indexes {
if idx >= 0 && idx < len(c.AllPolicies) {
pols = append(pols, c.AllPolicies[idx])
}
}
resourcePoliciesMap[resID] = pols
}
}
return &NetworkMapComponents{
PeerID: c.PeerID,
Network: c.Network,
AccountSettings: c.AccountSettings,
DNSSettings: c.DNSSettings,
CustomZoneDomain: c.CustomZoneDomain,
Peers: peers,
RouterPeers: routerPeers,
Groups: groups,
Policies: policies,
Routes: c.Routes,
NameServerGroups: c.NameServerGroups,
AllDNSRecords: c.AllDNSRecords,
AccountZones: c.AccountZones,
ResourcePoliciesMap: resourcePoliciesMap,
RoutersMap: c.RoutersMap,
NetworkResources: c.NetworkResources,
GroupIDToUserIDs: c.GroupIDToUserIDs,
AllowedUserIDs: c.AllowedUserIDs,
PostureFailedPeers: c.PostureFailedPeers,
}
}

View File

@@ -47,11 +47,6 @@ type Settings struct {
// NetworkRange is the custom network range for that account
NetworkRange netip.Prefix `gorm:"serializer:json"`
// PeerExposeEnabled enables or disables peer-initiated service expose
PeerExposeEnabled bool
// PeerExposeGroups list of peer group IDs allowed to expose services
PeerExposeGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
@@ -85,8 +80,6 @@ func (s *Settings) Copy() *Settings {
PeerInactivityExpiration: s.PeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: s.PeerExposeEnabled,
PeerExposeGroups: slices.Clone(s.PeerExposeGroups),
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,

Some files were not shown because too many files have changed in this diff Show More