Compare commits

...

21 Commits

Author SHA1 Message Date
pascal
95248e52f7 change default mtu 2026-03-05 09:57:01 +01:00
Maycon Santos
8e7b016be2 [management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart.

Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column:

- Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists
- Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass
- Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission 
- Add composite index on (source, source_peer) for efficient queries
- Batch-limit and column-select the reaper query to avoid DB/GC spikes
- Filter out malformed rows with empty source_peer
2026-03-04 18:15:13 +01:00
Maycon Santos
9e01ea7aae [misc] Add ISSUE_TEMPLATE configuration file (#5500)
Add issue template config file  with support and troubleshooting links
2026-03-04 14:30:54 +01:00
hbzhost
cfc7ec8bb9 [client] Fix SSH JWT auth failure with Azure Entra ID iat backdating (#5471)
Increase DefaultJWTMaxTokenAge from 5 to 10 minutes to accommodate
identity providers like Azure Entra ID that backdate the iat claim
by up to 5 minutes, causing tokens to be immediately rejected.

Fixes #5449

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-04 14:11:14 +01:00
Misha Bragin
b3bbc0e5c6 Fix embedded IdP metrics to count local and generic OIDC users (#5498) 2026-03-04 12:34:11 +02:00
Pascal Fischer
d7c8e37ff4 [management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-03-03 18:39:46 +01:00
Zoltan Papp
05b66e73bc [client] Fix deadlock in route peer status watcher (#5489)
Wrap peerStateUpdate send in a nested select to prevent goroutine
blocking when the consumer has exited, which could fill the
subscription buffer and deadlock the Status mutex.
2026-03-03 13:50:46 +01:00
Jeremie Deray
01ceedac89 [client] Fix profile config directory permissions (#5457)
* fix user profile dir perm

* fix fileExists

* revert return var change

* fix anti-pattern
2026-03-03 13:48:51 +01:00
Misha Bragin
403babd433 [self-hosted] specify sql file location of auth, activity and main store (#5487) 2026-03-03 12:53:16 +02:00
Maycon Santos
47133031e5 [client] fix: client/Dockerfile to reduce vulnerabilities (#5217)
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-03-03 08:44:08 +01:00
Pascal Fischer
82da606886 [management] Add explicit target delete on service removal (#5420) 2026-03-02 18:25:44 +01:00
Viktor Liu
bbe5ae2145 [client] Flush buffer immediately to support gprc (#5469) 2026-03-02 15:17:08 +01:00
Viktor Liu
0b21498b39 [client] Fix close of closed channel panic in ConnectClient retry loop (#5470) 2026-03-02 10:07:53 +01:00
Viktor Liu
0ca59535f1 [management] Add reverse proxy services REST client (#5454) 2026-02-28 13:04:58 +08:00
Misha Bragin
59c77d0658 [self-hosted] support embedded IDP postgres db (#5443)
* Add postgres config for embedded idp

Entire-Checkpoint: 9ace190c1067

* Rename idpStore to authStore

Entire-Checkpoint: 73a896c79614

* Fix review notes

Entire-Checkpoint: 6556783c0df3

* Don't accept pq port = 0

Entire-Checkpoint: 80d45e37782f

* Optimize configs

Entire-Checkpoint: 80d45e37782f

* Fix lint issues

Entire-Checkpoint: 3eec968003d1

* Fail fast on combined postgres config

Entire-Checkpoint: b17839d3d8c6

* Simplify management config method

Entire-Checkpoint: 0f083effa20e
2026-02-27 14:52:54 +01:00
shuuri-labs
333e045099 Lower socket auto-discovery log from Info to Debug (#5463)
The discovery message was printing on every CLI invocation, which is
noisy for users on distros using the systemd template.
2026-02-26 17:51:38 +01:00
Zoltan Papp
c2c4d9d336 [client] Fix Server mutex held across waitForUp in Up() (#5460)
Up() acquired s.mutex with a deferred unlock, then called waitForUp()
while still holding the lock. waitForUp() blocks for up to 50 seconds
waiting on clientRunningChan/clientGiveUpChan, starving all concurrent
gRPC calls that require the same mutex (Status, ListProfiles, etc.).

Replace the deferred unlock with explicit s.mutex.Unlock() on every
early-return path and immediately before waitForUp(), matching the
pattern already used by the clientRunning==true branch.
2026-02-26 16:47:02 +01:00
Bethuel Mmbaga
9a6a72e88e [management] Fix user update permission validation (#5441) 2026-02-24 22:47:41 +03:00
Bethuel Mmbaga
afe6d9fca4 [management] Prevent deletion of groups linked to flow groups (#5439) 2026-02-24 21:19:43 +03:00
shuuri-labs
ef82905526 [client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket.
- Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
2026-02-24 17:02:06 +01:00
Zoltan Papp
d18747e846 [client] Exclude Flow domain from caching to prevent TLS failures (#5433)
* Exclude Flow domain from caching to prevent TLS failures due to stale records.

* Fix test
2026-02-24 16:48:38 +01:00
86 changed files with 5140 additions and 1510 deletions

14
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2 FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
) )
@@ -80,6 +81,15 @@ var (
Short: "", Short: "",
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
} }
) )
@@ -386,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -399,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil return conn, nil
} }
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -26,7 +26,7 @@ import (
) )
const ( const (
DefaultMTU = 1280 DefaultMTU = 1420
MinMTU = 576 MinMTU = 576
MaxMTU = 8192 MaxMTU = 8192
DefaultWgPort = 51820 DefaultWgPort = 51820

View File

@@ -331,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
close(runningChan) select {
runningChan = nil case <-runningChan:
default:
close(runningChan)
}
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -0,0 +1,60 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -0,0 +1,8 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -0,0 +1,121 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
} }
} }
if serverDomains.Flow != "" { // Flow receiver domain is intentionally excluded from caching.
domains = append(domains, serverDomains.Flow) // Cloud providers may rotate the IP behind this domain; a stale cached record
} // causes TLS certificate verification failures on reconnect.
for _, stun := range serverDomains.Stuns { for _, stun := range serverDomains.Stuns {
if stun != "" { if stun != "" {

View File

@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
} }
assert.Len(t, resolver.GetCachedDomains(), 3) assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing) // Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
partialDomains := dnsconfig.ServerDomains{ partialDomains := dnsconfig.ServerDomains{
Flow: "github.com", Flow: "github.com",
} }
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err) t.Skipf("Skipping test due to DNS resolution failure: %v", err)
} }
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
finalDomains := resolver.GetCachedDomains() finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
domainStrings := make([]string, len(finalDomains)) domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains { for i, d := range finalDomains {
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org") assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com") assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com") assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com") assert.NotContains(t, domainStrings, "github.com")
} }

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username) configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) { if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0600); err != nil { if err := os.MkdirAll(configDir, 0700); err != nil {
return "", err return "", err
} }
} }
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil return configDir, nil
} }
func fileExists(path string) bool { func fileExists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
return !os.IsNotExist(err) if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration // UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath) return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
} }
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one // UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath) err = util.EnforcePermission(input.ConfigPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
} }
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) { func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes. // DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). // Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) { func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {

View File

@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists return ErrProfileAlreadyExists
} }
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName) return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if !fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound return ErrProfileNotFound
} }

View File

@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
} }
stateFile := filepath.Join(configDir, profileName+".state.json") stateFile := filepath.Join(configDir, profileName+".state.json")
if !fileExists(stateFile) { stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
return nil, errors.New("profile state file does not exist") return nil, errors.New("profile state file does not exist")
} }

View File

@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer: case <-closer:
return return
case routerStates := <-subscription.Events(): case routerStates := <-subscription.Events():
peerStateUpdate <- routerStates select {
log.Debugf("triggered route state update for Peer: %s", peerKey) case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
} }
} }
} }

View File

@@ -641,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err) log.Warnf(errRestoreResidualState, err)
} }
@@ -654,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
// not in the progress or already successfully established connection. // not in the progress or already successfully established connection.
status, err := state.Status() status, err := state.Status()
if err != nil { if err != nil {
s.mutex.Unlock()
return nil, err return nil, err
} }
if status != internal.StatusIdle { if status != internal.StatusIdle {
s.mutex.Unlock()
return nil, fmt.Errorf("up already in progress: current status %s", status) return nil, fmt.Errorf("up already in progress: current status %s", status)
} }
@@ -674,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.actCancel = cancel s.actCancel = cancel
if s.config == nil { if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first") return nil, fmt.Errorf("config is not defined, please call login command first")
} }
activeProf, err := s.profileManager.GetActiveProfileState() activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
if msg != nil && msg.ProfileName != nil { if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err) log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err) return nil, fmt.Errorf("failed to switch profile: %w", err)
} }
@@ -692,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
activeProf, err = s.profileManager.GetActiveProfileState() activeProf, err = s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
@@ -700,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
config, _, err := s.getConfig(activeProf) config, _, err := s.getConfig(activeProf)
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile config: %v", err) log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err) return nil, fmt.Errorf("failed to get active profile config: %w", err)
} }
@@ -718,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
} }
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows return DefaultDaemonAddrWindows
} }
return DefaultDaemonAddr return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
} }
// DialOptions contains options for SSH connections // DialOptions contains options for SSH connections

View File

@@ -46,8 +46,10 @@ const (
cmdSFTP = "<sftp>" cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>" cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
DefaultJWTMaxTokenAge = 5 * 60 // Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
) )
var ( var (

View File

@@ -7,6 +7,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"path" "path"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -71,6 +72,7 @@ type ServerConfig struct {
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"` Store StoreConfig `yaml:"store"`
ActivityStore StoreConfig `yaml:"activityStore"` ActivityStore StoreConfig `yaml:"activityStore"`
AuthStore StoreConfig `yaml:"authStore"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"` ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
} }
@@ -171,7 +173,8 @@ type RelaysConfig struct {
type StoreConfig struct { type StoreConfig struct {
Engine string `yaml:"engine"` Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"` EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
} }
// ReverseProxyConfig contains reverse proxy settings // ReverseProxyConfig contains reverse proxy settings
@@ -533,6 +536,74 @@ func stripSignalProtocol(uri string) string {
return uri return uri
} }
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
var ttl time.Duration
if relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
}
}
return &nbconfig.Relay{
Addresses: relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: relays.Secret,
}, nil
}
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
// authStore overrides auth.storage when set.
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
authStorageType := mgmt.Auth.Storage.Type
authStorageDSN := c.Server.AuthStore.DSN
if c.Server.AuthStore.Engine != "" {
authStorageType = c.Server.AuthStore.Engine
}
if authStorageType == "" {
authStorageType = "sqlite3"
}
authStorageFile := ""
if authStorageType == "postgres" {
if authStorageDSN == "" {
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
}
}
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
File: authStorageFile,
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
cfg.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password,
}
}
return cfg, nil
}
// ToManagementConfig converts CombinedConfig to management server config // ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management mgmt := c.Management
@@ -551,19 +622,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build relay config // Build relay config
var relayConfig *nbconfig.Relay var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" { if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration relay, err := buildRelayConfig(mgmt.Relays)
if mgmt.Relays.CredentialsTTL != "" { if err != nil {
var err error return nil, err
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
} }
relayConfig = relay
} }
// Build signal config // Build signal config
@@ -599,31 +662,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
httpConfig := &nbconfig.HttpServerConfig{} httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server) // Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
if storageFile == "" { if err != nil {
storageFile = path.Join(mgmt.DataDir, "idp.db") return nil, err
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
} }
// Set HTTP config fields for embedded IDP // Set HTTP config fields for embedded IDP

View File

@@ -140,6 +140,9 @@ func initializeConfig() error {
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
} }
} }
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
if engine := config.Server.ActivityStore.Engine; engine != "" { if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine) engineLower := strings.ToLower(engine)
@@ -151,6 +154,9 @@ func initializeConfig() error {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn) os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
} }
} }
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
log.Infof("Starting combined NetBird server") log.Infof("Starting combined NetBird server")
logConfig(config) logConfig(config)

View File

@@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
} }
} }
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine) engine := types.Engine(cfg.Management.Store.Engine)

View File

@@ -103,11 +103,19 @@ server:
engine: "sqlite" # sqlite, postgres, or mysql engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql dsn: "" # Connection string for postgres or mysql
encryptionKey: "" encryptionKey: ""
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
# Activity events store configuration (optional, defaults to sqlite in dataDir) # Activity events store configuration (optional, defaults to sqlite in dataDir)
# activityStore: # activityStore:
# engine: "sqlite" # sqlite or postgres # engine: "sqlite" # sqlite or postgres
# dsn: "" # Connection string for postgres # dsn: "" # Connection string for postgres
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
# authStore:
# engine: "sqlite3" # sqlite3 or postgres
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
# Reverse proxy settings (optional) # Reverse proxy settings (optional)
# reverseProxy: # reverseProxy:

View File

@@ -5,7 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"net/url"
"os" "os"
"strconv"
"strings"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -195,11 +198,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config") return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
} }
return (&sql.SQLite3{File: file}).Open(logger) return (&sql.SQLite3{File: file}).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
}
pg, err := parsePostgresDSN(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
}
return pg.Open(logger)
default: default:
return nil, fmt.Errorf("unsupported storage type: %s", s.Type) return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
} }
} }
// parsePostgresDSN parses a DSN into a sql.Postgres config.
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
params, err = parsePostgresURI(dsn)
} else {
params, err = parsePostgresKeyValue(dsn)
}
if err != nil {
return nil, err
}
host := params["host"]
if host == "" {
host = "localhost"
}
var port uint16 = 5432
if p, ok := params["port"]; ok && p != "" {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %w", p, err)
}
if v == 0 {
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
}
port = uint16(v)
}
dbname := params["dbname"]
if dbname == "" {
return nil, fmt.Errorf("dbname is required in DSN")
}
pg := &sql.Postgres{
NetworkDB: sql.NetworkDB{
Host: host,
Port: port,
Database: dbname,
User: params["user"],
Password: params["password"],
},
}
if sslMode := params["sslmode"]; sslMode != "" {
switch sslMode {
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
pg.SSL.Mode = sslMode
default:
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
}
}
return pg, nil
}
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
func parsePostgresURI(dsn string) (map[string]string, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres URI: %w", err)
}
params := make(map[string]string)
if u.User != nil {
params["user"] = u.User.Username()
if p, ok := u.User.Password(); ok {
params["password"] = p
}
}
if u.Hostname() != "" {
params["host"] = u.Hostname()
}
if u.Port() != "" {
params["port"] = u.Port()
}
dbname := strings.TrimPrefix(u.Path, "/")
if dbname != "" {
params["dbname"] = dbname
}
for k, v := range u.Query() {
if len(v) > 0 {
params[k] = v[0]
}
}
return params, nil
}
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
// (e.g., password='my pass' host=localhost).
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
params := make(map[string]string)
s := strings.TrimSpace(dsn)
for s != "" {
eqIdx := strings.IndexByte(s, '=')
if eqIdx < 0 {
break
}
key := strings.TrimSpace(s[:eqIdx])
value, rest, err := parseDSNValue(s[eqIdx+1:])
if err != nil {
return nil, fmt.Errorf("%w for key %q", err, key)
}
params[key] = value
s = strings.TrimSpace(rest)
}
return params, nil
}
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
// It returns the parsed value and the remaining unparsed string.
func parseDSNValue(s string) (value, rest string, err error) {
if len(s) > 0 && s[0] == '\'' {
return parseQuotedDSNValue(s[1:])
}
// Unquoted value: read until whitespace.
idx := strings.IndexAny(s, " \t\n")
if idx < 0 {
return s, "", nil
}
return s[:idx], s[idx:], nil
}
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
// Libpq uses ” to represent a literal single quote inside quoted values.
func parseQuotedDSNValue(s string) (value, rest string, err error) {
var buf strings.Builder
for len(s) > 0 {
if s[0] == '\'' {
if len(s) > 1 && s[1] == '\'' {
buf.WriteByte('\'')
s = s[2:]
continue
}
return buf.String(), s[1:], nil
}
buf.WriteByte(s[0])
s = s[1:]
}
return "", "", fmt.Errorf("unterminated quoted value")
}
// Validate validates the configuration // Validate validates the configuration
func (c *YAMLConfig) Validate() error { func (c *YAMLConfig) Validate() error {
if c.Issuer == "" { if c.Issuer == "" {

View File

@@ -27,21 +27,21 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
} }
type proxyURLProvider interface { type proxyManager interface {
GetConnectedProxyURLs() []string GetActiveClusterAddresses(ctx context.Context) ([]string, error)
} }
type Manager struct { type Manager struct {
store store store store
validator domain.Validator validator domain.Validator
proxyURLProvider proxyURLProvider proxyManager proxyManager
permissionsManager permissions.Manager permissionsManager permissions.Manager
} }
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyURLProvider: proxyURLProvider, proxyManager: proxyMgr,
validator: domain.Validator{ validator: domain.Validator{
Resolver: net.DefaultResolver, Resolver: net.DefaultResolver,
}, },
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains. // Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
log.WithFields(log.Fields{ if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
"proxyAllowList": allowList, "proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list") }).Debug("getting domains with proxy allow list")
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
} }
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false clusterValid := false
for _, cluster := range allowList { for _, cluster := range allowList {
if cluster == targetCluster { if cluster == targetCluster {
@@ -221,25 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
} }
} }
// GetClusterDomains returns a list of proxy cluster domains.
func (m Manager) GetClusterDomains() []string { func (m Manager) GetClusterDomains() []string {
return m.proxyURLAllowList() if m.proxyManager == nil {
} return nil
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
} }
return reverseProxyAddresses addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
} }
// DeriveClusterFromDomain determines the proxy cluster for a given domain. // DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 { if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available") return "", fmt.Errorf("no proxy clusters available")
} }

View File

@@ -1,163 +0,0 @@
package manager
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
)
type trackedExpose struct {
mu sync.Mutex
domain string
accountID string
peerID string
lastRenewed time.Time
expiring bool
}
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *managerImpl
}
func exposeKey(peerID, domain string) string {
return peerID + ":" + domain
}
// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new
// active expose session under the same lock. Returns (true, false) if the expose
// was already tracked (duplicate), (false, true) if tracking succeeded, and
// (false, false) if the peer has reached the limit.
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
t.exposeCreateMu.Lock()
defer t.exposeCreateMu.Unlock()
key := exposeKey(peerID, domain)
_, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{
domain: domain,
accountID: accountID,
peerID: peerID,
lastRenewed: time.Now(),
})
if loaded {
return true, false
}
if t.CountPeerExposes(peerID) > maxExposesPerPeer {
t.activeExposes.Delete(key)
return false, false
}
return false, true
}
// UntrackExpose removes an active expose session from tracking.
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
t.activeExposes.Delete(exposeKey(peerID, domain))
}
// CountPeerExposes returns the number of active expose sessions for a peer.
func (t *exposeTracker) CountPeerExposes(peerID string) int {
count := 0
t.activeExposes.Range(func(_, val any) bool {
if expose := val.(*trackedExpose); expose.peerID == peerID {
count++
}
return true
})
return count
}
// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer.
func (t *exposeTracker) MaxExposesPerPeer() int {
return maxExposesPerPeer
}
// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose.
// Returns false if the expose is not tracked or is being reaped.
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
val, ok := t.activeExposes.Load(key)
if !ok {
return false
}
expose := val.(*trackedExpose)
expose.mu.Lock()
if expose.expiring {
expose.mu.Unlock()
return false
}
expose.lastRenewed = time.Now()
expose.mu.Unlock()
return true
}
// StopTrackedExpose removes an active expose session from tracking.
// Returns false if the expose was not tracked.
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
_, ok := t.activeExposes.LoadAndDelete(key)
return ok
}
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
go func() {
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.reapExpiredExposes()
}
}
}()
}
func (t *exposeTracker) reapExpiredExposes() {
t.activeExposes.Range(func(key, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expired := time.Since(expose.lastRenewed) > exposeTTL
if expired {
expose.expiring = true
}
expose.mu.Unlock()
if !expired {
return true
}
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
s, _ := status.FromError(err)
switch {
case err == nil:
t.activeExposes.Delete(key)
case s.ErrorType == status.NotFound:
log.Debugf("service %s was already deleted", expose.domain)
default:
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
}
return true
})
}

View File

@@ -1,256 +0,0 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
)
func TestExposeKey(t *testing.T) {
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
}
func TestTrackExposeIfAllowed(t *testing.T) {
t.Run("first track succeeds", func(t *testing.T) {
tracker := &exposeTracker{}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.False(t, alreadyTracked, "first track should not be duplicate")
assert.True(t, ok, "first track should be allowed")
})
t.Run("duplicate track detected", func(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.True(t, alreadyTracked, "second track should be duplicate")
assert.False(t, ok)
})
t.Run("rejects when at limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
assert.True(t, ok, "track %d should be allowed", i)
}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
assert.False(t, alreadyTracked)
assert.False(t, ok, "should reject when at limit")
})
t.Run("other peer unaffected by limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
}
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.True(t, ok, "other peer should still be within limit")
})
}
func TestUntrackExpose(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
tracker.UntrackExpose("peer1", "a.com")
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
}
func TestCountPeerExposes(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
}
func TestMaxExposesPerPeer(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
}
func TestRenewTrackedExpose(t *testing.T) {
tracker := &exposeTracker{}
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should not find untracked expose")
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
found = tracker.RenewTrackedExpose("peer1", "a.com")
assert.True(t, found, "should find tracked expose")
}
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
// Simulate reaper marking the expose as expiring
key := exposeKey("peer1", "a.com")
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.expiring = true
expose.mu.Unlock()
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should reject renewal when expiring")
}
func TestReapExpiredExposes(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the tracked entry
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Add an active (non-expired) tracking entry
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
domain: "active.com",
accountID: testAccountID,
peerID: "peer1",
lastRenewed: time.Now(),
})
tracker.reapExpiredExposes()
_, exists := tracker.activeExposes.Load(key)
assert.False(t, exists, "expired expose should be removed")
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
assert.True(t, exists, "active expose should remain")
}
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
// Expire it
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Renew should succeed before reaping
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
// Re-expire and reap
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
tracker.reapExpiredExposes()
// Entry is deleted, renew returns false
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
}
func TestConcurrentTrackAndCount(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Manually expire all tracked entries
tracker.activeExposes.Range(func(_, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
return true
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tracker.reapExpiredExposes()
}()
go func() {
defer wg.Done()
tracker.CountPeerExposes(testPeerID)
}()
wg.Wait()
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
}
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
expose := &trackedExpose{
lastRenewed: time.Now().Add(-1 * time.Hour),
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
}
}()
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
_ = time.Since(expose.lastRenewed)
expose.mu.Unlock()
}
}()
wg.Wait()
expose.mu.Lock()
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
expose.mu.Unlock()
}

View File

@@ -0,0 +1,36 @@
package proxy
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -0,0 +1,88 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
metrics *metrics
}
// NewGRPCController creates a new GRPCController.
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &GRPCController{
proxyGRPCServer: proxyGRPCServer,
metrics: m,
}, nil
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
c.metrics.IncrementProxyConnectionCount(clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
c.metrics.DecrementProxyConnectionCount(clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -0,0 +1,115 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
metrics *metrics
}
// NewManager creates a new proxy Manager
func NewManager(store store, meter metric.Meter) (*Manager, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &Manager{
store: store,
metrics: m,
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type metrics struct {
proxyConnectionCount metric.Int64UpDownCounter
serviceUpdateSendCount metric.Int64Counter
proxyHeartbeatCount metric.Int64Counter
}
func newMetrics(meter metric.Meter) (*metrics, error) {
proxyConnectionCount, err := meter.Int64UpDownCounter(
"management_proxy_connection_count",
metric.WithDescription("Number of active proxy connections"),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, err
}
serviceUpdateSendCount, err := meter.Int64Counter(
"management_proxy_service_update_send_count",
metric.WithDescription("Total number of service updates sent to proxies"),
metric.WithUnit("{update}"),
)
if err != nil {
return nil, err
}
proxyHeartbeatCount, err := meter.Int64Counter(
"management_proxy_heartbeat_count",
metric.WithDescription("Total number of proxy heartbeats received"),
metric.WithUnit("{heartbeat}"),
)
if err != nil {
return nil, err
}
return &metrics{
proxyConnectionCount: proxyConnectionCount,
serviceUpdateSendCount: serviceUpdateSendCount,
proxyHeartbeatCount: proxyHeartbeatCount,
}, nil
}
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), -1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
m.serviceUpdateSendCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementProxyHeartbeatCount() {
m.proxyHeartbeatCount.Add(context.Background(), 1)
}

View File

@@ -0,0 +1,199 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./manager.go
// Package proxy is a generated GoMock package.
package proxy
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CleanupStale mocks base method.
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStale indicates an expected call of CleanupStale.
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,6 +1,6 @@
package reverseproxy package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod //go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import ( import (
"context" "context"
@@ -14,7 +14,7 @@ type Manager interface {
DeleteService(ctx context.Context, accountID, userID, serviceID string) error DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error) GetGlobalServices(ctx context.Context) ([]*Service, error)

View File

@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go // Source: ./interface.go
// Package reverseproxy is a generated GoMock package. // Package service is a generated GoMock package.
package reverseproxy package service
import ( import (
context "context" context "context"
@@ -239,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
} }
// SetStatus mocks base method. // SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error { func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status) ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
) )
type handler struct { type handler struct {
manager reverseproxy.Manager manager rpservice.Manager
} }
// RegisterEndpoints registers all service HTTP endpoints. // RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId) service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId) service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -0,0 +1,65 @@
package manager
import (
"context"
"math/rand/v2"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
exposeReapBatch = 100
)
type exposeReaper struct {
manager *Manager
}
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
go func() {
// start with a random delay
rn := rand.IntN(10)
time.Sleep(time.Duration(rn) * time.Second)
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
r.reapExpiredExposes(ctx)
}
}
}()
}
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
if err != nil {
log.Errorf("failed to get expired ephemeral services: %v", err)
return
}
for _, svc := range expired {
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
if err == nil {
continue
}
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
log.Debugf("service %s was already deleted by another instance", svc.Domain)
} else {
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
}
}
}

View File

@@ -0,0 +1,208 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
func TestReapExpiredExposes(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the service by backdating meta_last_renewed_at
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
require.NoError(t, err, "active service should remain")
}
func TestReapAlreadyDeletedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
mgr.exposeReaper.reapExpiredExposes(ctx)
}
func TestConcurrentReapAndRenew(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Expire all services
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
for _, svc := range services {
if svc.Source == rpservice.SourceEphemeral {
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
}
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
mgr.exposeReaper.reapExpiredExposes(ctx)
}()
go func() {
defer wg.Done()
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
}()
wg.Wait()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all expired services should be reaped")
}
func TestRenewEphemeralService(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
}
func TestCountAndExistsEphemeralServices(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
})
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
assert.True(t, exists, "service should exist")
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
require.NoError(t, err)
assert.False(t, exists, "non-existent service should not exist")
}
func TestMaxExposesPerPeerEnforced(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
}
func TestReapSkipsRenewedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
})
require.NoError(t, err)
// Expire the service
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)
svc.Meta.LastRenewedAt = &expired
err = s.UpdateService(context.Background(), svc)
require.NoError(t, err)
}

View File

@@ -4,24 +4,22 @@ import (
"context" "context"
"fmt" "fmt"
"math/rand/v2" "math/rand/v2"
"slices"
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"slices" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -33,36 +31,34 @@ type ClusterDeriver interface {
GetClusterDomains() []string GetClusterDomains() []string
} }
type managerImpl struct { type Manager struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager permissionsManager permissions.Manager
settingsManager settings.Manager proxyController proxy.Controller
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver clusterDeriver ClusterDeriver
exposeTracker *exposeTracker exposeReaper *exposeReaper
} }
// NewManager creates a new service manager. // NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
mgr := &managerImpl{ mgr := &Manager{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
settingsManager: settingsManager, proxyController: proxyController,
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver, clusterDeriver: clusterDeriver,
} }
mgr.exposeTracker = &exposeTracker{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr return mgr
} }
// StartExposeReaper delegates to the expose tracker. // StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
func (m *managerImpl) StartExposeReaper(ctx context.Context) { func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeTracker.StartExposeReaper(ctx) m.exposeReaper.StartExposeReaper(ctx)
} }
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -86,34 +82,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return services, nil return services, nil
} }
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range service.Targets { for _, target := range s.Targets {
switch target.TargetType { switch target.TargetType {
case reverseproxy.TargetTypePeer: case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = peer.IP.String() target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost: case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = resource.Prefix.Addr().String() target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain: case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = resource.Domain target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet: case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource // For subnets we do not do any lookups on the resource
default: default:
return fmt.Errorf("unknown target type: %s", target.TargetType) return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -122,7 +118,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil return nil
} }
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -143,7 +139,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return service, nil return service, nil
} }
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -152,29 +148,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
if err := m.persistNewService(ctx, accountID, service); err != nil { if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil return s, nil
} }
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil { if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil { if err != nil {
@@ -201,7 +197,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
return nil return nil
} }
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err return err
@@ -219,7 +215,53 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
}) })
} }
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { // persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
@@ -235,7 +277,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
return nil return nil
} }
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -259,7 +301,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
} }
m.sendServiceUpdateNotifications(service, updateInfo) m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil return service, nil
@@ -271,7 +313,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool serviceEnabledChanged bool
} }
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -309,7 +351,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
return &updateInfo, err return &updateInfo, err
} }
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err return err
} }
@@ -326,7 +368,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
return nil return nil
} }
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" { service.Auth.PasswordAuth.Password == "" {
@@ -340,54 +382,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
} }
} }
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) { func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey service.SessionPublicKey = existingService.SessionPublicKey
} }
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch { switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged: case !s.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged: case s.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default: default:
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
} }
} }
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
}
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
if len(mappings) == 0 {
return
}
update := &proto.GetMappingUpdateResponse{
Mapping: mappings,
}
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account. // validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error { func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets { for _, target := range targets {
switch target.TargetType { switch target.TargetType {
case reverseproxy.TargetTypePeer: case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
} }
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
} }
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain: case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -399,7 +427,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil return nil
} }
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil { if err != nil {
return status.NewPermissionValidationError(err) return status.NewPermissionValidationError(err)
@@ -408,14 +436,18 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
var service *reverseproxy.Service var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
return err return err
} }
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err) return fmt.Errorf("failed to delete service: %w", err)
} }
@@ -426,20 +458,16 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err return err
} }
if service.Source == reverseproxy.SourceEphemeral { m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error { func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil { if err != nil {
return status.NewPermissionValidationError(err) return status.NewPermissionValidationError(err)
@@ -448,16 +476,16 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
var services []*reverseproxy.Service var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID) services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return err return err
} }
for _, service := range services { for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil { if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err) return fmt.Errorf("failed to delete service: %w", err)
} }
} }
@@ -468,20 +496,11 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return err return err
} }
clusterMappings := make(map[string][]*proto.ProxyMapping) oidcCfg := m.proxyController.GetOIDCValidationConfig()
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
for _, service := range services { for _, svc := range services {
if service.Source == reverseproxy.SourceEphemeral { m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
} }
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -491,7 +510,7 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time. // SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued. // Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -510,7 +529,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
} }
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) // SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -527,50 +546,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
}) })
} }
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error { func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get service: %w", err) return fmt.Errorf("failed to get service: %w", err)
} }
err = m.replaceHostByLookup(ctx, accountID, service) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get services: %w", err) return fmt.Errorf("failed to get services: %w", err)
} }
clusterMappings := make(map[string][]*proto.ProxyMapping) for _, s := range services {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() err = m.replaceHostByLookup(ctx, accountID, s)
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil { if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
} }
return nil return nil
} }
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone) services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -586,7 +597,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
return services, nil return services, nil
} }
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err) return nil, fmt.Errorf("failed to get service: %w", err)
@@ -600,7 +611,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
return service, nil return service, nil
} }
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -616,7 +627,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
return services, nil return services, nil
} }
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
@@ -634,7 +645,7 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
// validateExposePermission checks whether the peer is allowed to use the expose feature. // validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group. // It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error { func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
@@ -667,7 +678,7 @@ func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, p
// CreateServiceFromPeer creates a service initiated by a peer expose request. // CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit, // It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping. // creates the service, and tracks it for TTL-based reaping.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil { if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err) return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
} }
@@ -676,31 +687,31 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err return nil, err
} }
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix) serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil { if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err) return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
} }
service := req.ToService(accountID, peerID, serviceName) svc := req.ToService(accountID, peerID, serviceName)
service.Source = reverseproxy.SourceEphemeral svc.Source = service.SourceEphemeral
if service.Domain == "" { if svc.Domain == "" {
domain, err := m.buildRandomDomain(service.Name) domain, err := m.buildRandomDomain(svc.Name)
if err != nil { if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", service.Name, err) return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
} }
service.Domain = domain svc.Domain = domain
} }
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups) groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err) return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
} }
service.Auth.BearerAuth.DistributionGroups = groupIDs svc.Auth.BearerAuth.DistributionGroups = groupIDs
} }
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err return nil, err
} }
@@ -709,46 +720,33 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err return nil, err
} }
now := time.Now() svc.SourcePeer = peerID
service.Meta.LastRenewedAt = &now
service.SourcePeer = peerID
if err := m.persistNewService(ctx, accountID, service); err != nil { now := time.Now()
svc.Meta.LastRenewedAt = &now
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
return nil, err return nil, err
} }
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID) meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
if alreadyTracked { m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err) if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
} return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
} }
meta := addPeerInfoToEventMeta(service.EventMeta(), peer) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return &reverseproxy.ExposeServiceResponse{ return &service.ExposeServiceResponse{
ServiceName: service.Name, ServiceName: svc.Name,
ServiceURL: "https://" + service.Domain, ServiceURL: "https://" + svc.Domain,
Domain: service.Domain, Domain: svc.Domain,
}, nil }, nil
} }
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) { func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 { if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided") return []string{}, fmt.Errorf("no group names provided")
} }
@@ -763,7 +761,7 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
return groupIDs, nil return groupIDs, nil
} }
func (m *managerImpl) buildRandomDomain(name string) (string, error) { func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil { if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain") return "", fmt.Errorf("unable to get random domain")
} }
@@ -776,33 +774,24 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
return domain, nil return domain, nil
} }
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session. // RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
// Returns an error if the expose is not actively tracked. func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error { return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
} }
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service. // StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err return err
} }
if !m.exposeTracker.StopTrackedExpose(peerID, domain) {
log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
}
return nil return nil
} }
// deleteServiceFromPeer deletes a peer-initiated service identified by domain. // deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. // When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
service, err := m.lookupPeerService(ctx, accountID, peerID, domain) svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil { if err != nil {
return err return err
} }
@@ -811,41 +800,41 @@ func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peer
if expired { if expired {
activityCode = activity.PeerServiceExposeExpired activityCode = activity.PeerServiceExposeExpired
} }
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode) return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
} }
// lookupPeerService finds a peer-initiated service by domain and validates ownership. // lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) { func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
service, err := m.store.GetServiceByDomain(ctx, accountID, domain) svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if service.Source != reverseproxy.SourceEphemeral { if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
} }
if service.SourcePeer != peerID { if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
} }
return service, nil return svc, nil
} }
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var service *reverseproxy.Service var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
return err return err
} }
if service.Source != reverseproxy.SourceEphemeral { if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose") return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
} }
if service.SourcePeer != peerID { if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer") return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
} }
@@ -865,17 +854,68 @@ func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID,
peer = nil peer = nil
} }
meta := addPeerInfoToEventMeta(service.EventMeta(), peer) meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta) m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
// that it is still expired under a row lock. This prevents deleting a service
// that was renewed between the batch query and this delete, and ensures only one
// management instance processes the deletion
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
var svc *service.Service
deleted := false
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
}
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
return nil
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
deleted = true
return nil
})
if err != nil {
return err
}
if !deleted {
return nil
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any { func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
if peer == nil { if peer == nil {
return meta return meta

View File

@@ -10,18 +10,21 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -30,13 +33,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
accountID := "test-account" accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) { t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{ mgr := &Manager{
clusterDeriver: nil, clusterDeriver: nil,
} }
service := &reverseproxy.Service{ service := &rpservice.Service{
Domain: "example.com", Domain: "example.com",
Auth: reverseproxy.AuthConfig{}, Auth: rpservice.AuthConfig{},
} }
err := mgr.initializeServiceForCreate(ctx, accountID, service) err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -50,12 +53,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
}) })
t.Run("verifies session keys are different", func(t *testing.T) { t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{ mgr := &Manager{
clusterDeriver: nil, clusterDeriver: nil,
} }
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}} service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}} service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1) err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2) err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -97,7 +100,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
errorType: status.AlreadyExists, errorType: status.AlreadyExists,
@@ -109,7 +112,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: false, expectedError: false,
}, },
@@ -120,7 +123,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
errorType: status.AlreadyExists, errorType: status.AlreadyExists,
@@ -146,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore) tt.setupMock(mockStore)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError { if tt.expectedError {
@@ -176,7 +179,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, ""). GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found")) Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err) assert.NoError(t, err)
@@ -189,9 +192,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT(). mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com"). GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil) Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err) assert.Error(t, err)
@@ -209,7 +212,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "nil.com"). GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil) Return(nil, nil)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err) assert.NoError(t, err)
@@ -225,10 +228,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{ service := &rpservice.Service{
ID: "service-123", ID: "service-123",
Domain: "new.com", Domain: "new.com",
Targets: []*reverseproxy.Target{}, Targets: []*rpservice.Target{},
} }
// Mock ExecuteInTransaction to execute the function immediately // Mock ExecuteInTransaction to execute the function immediately
@@ -247,7 +250,7 @@ func TestPersistNewService(t *testing.T) {
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{store: mockStore} mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service) err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err) assert.NoError(t, err)
@@ -258,10 +261,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{ service := &rpservice.Service{
ID: "service-123", ID: "service-123",
Domain: "existing.com", Domain: "existing.com",
Targets: []*reverseproxy.Target{}, Targets: []*rpservice.Target{},
} }
mockStore.EXPECT(). mockStore.EXPECT().
@@ -270,12 +273,12 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl) txMock := store.NewMockStore(ctrl)
txMock.EXPECT(). txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com"). GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil) Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{store: mockStore} mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service) err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err) require.Error(t, err)
@@ -285,21 +288,21 @@ func TestPersistNewService(t *testing.T) {
}) })
} }
func TestPreserveExistingAuthSecrets(t *testing.T) { func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{} mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) { t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "hashed-password", Password: "hashed-password",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "", Password: "",
}, },
@@ -312,18 +315,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}) })
t.Run("preserve pin when empty", func(t *testing.T) { t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{ PinAuth: &rpservice.PINAuthConfig{
Enabled: true, Enabled: true,
Pin: "hashed-pin", Pin: "hashed-pin",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{ PinAuth: &rpservice.PINAuthConfig{
Enabled: true, Enabled: true,
Pin: "", Pin: "",
}, },
@@ -336,18 +339,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}) })
t.Run("do not preserve when password is provided", func(t *testing.T) { t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "old-password", Password: "old-password",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "new-password", Password: "new-password",
}, },
@@ -362,10 +365,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
} }
func TestPreserveServiceMetadata(t *testing.T) { func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{} mgr := &Manager{}
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Meta: reverseproxy.ServiceMeta{ Meta: rpservice.Meta{
CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(), CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(),
Status: "active", Status: "active",
}, },
@@ -373,7 +376,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key", SessionPublicKey: "public-key",
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Domain: "updated.com", Domain: "updated.com",
} }
@@ -397,31 +400,32 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
IP: net.ParseIP("100.64.0.1"), IP: net.ParseIP("100.64.0.1"),
} }
newEphemeralService := func() *reverseproxy.Service { newEphemeralService := func() *rpservice.Service {
return &reverseproxy.Service{ return &rpservice.Service{
ID: serviceID, ID: serviceID,
AccountID: accountID, AccountID: accountID,
Name: "test-service", Name: "test-service",
Domain: "test.example.com", Domain: "test.example.com",
Source: reverseproxy.SourceEphemeral, Source: rpservice.SourceEphemeral,
SourcePeer: ownerPeerID, SourcePeer: ownerPeerID,
} }
} }
newPermanentService := func() *reverseproxy.Service { newPermanentService := func() *rpservice.Service {
return &reverseproxy.Service{ return &rpservice.Service{
ID: serviceID, ID: serviceID,
AccountID: accountID, AccountID: accountID,
Name: "api-service", Name: "api-service",
Domain: "api.example.com", Domain: "api.example.com",
Source: reverseproxy.SourcePermanent, Source: rpservice.SourcePermanent,
} }
} }
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
t.Helper() t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour) tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil) require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
return srv return srv
} }
@@ -455,10 +459,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil) Return(testPeer, nil)
mgr := &managerImpl{ mgr := &Manager{
store: mockStore, store: mockStore,
accountManager: mockAccountMgr, accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t), proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
} }
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -482,7 +490,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{ mgr := &Manager{
store: mockStore, store: mockStore,
} }
@@ -511,7 +519,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{ mgr := &Manager{
store: mockStore, store: mockStore,
} }
@@ -553,10 +561,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil) Return(testPeer, nil)
mgr := &managerImpl{ mgr := &Manager{
store: mockStore, store: mockStore,
accountManager: mockAccountMgr, accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t), proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
} }
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired) err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired)
@@ -593,10 +605,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil) Return(testPeer, nil)
mgr := &managerImpl{ mgr := &Manager{
store: mockStore, store: mockStore,
accountManager: mockAccountMgr, accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t), proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
} }
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -609,19 +625,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
}) })
} }
// noopExtraSettings is a minimal extra_settings.Manager for tests without external integrations.
type noopExtraSettings struct{}
func (n *noopExtraSettings) GetExtraSettings(_ context.Context, _ string) (*types.ExtraSettings, error) {
return &types.ExtraSettings{}, nil
}
func (n *noopExtraSettings) UpdateExtraSettings(_ context.Context, _, _ string, _ *types.ExtraSettings) (bool, error) {
return false, nil
}
var _ extra_settings.Manager = (*noopExtraSettings)(nil)
// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list. // testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list.
type testClusterDeriver struct { type testClusterDeriver struct {
domains []string domains []string
@@ -643,7 +646,7 @@ const (
) )
// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests. // setupIntegrationTest creates a real SQLite store with seeded test data for integration tests.
func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) { func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
@@ -691,35 +694,33 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
require.NoError(t, err) require.NoError(t, err)
permsMgr := permissions.NewManager(testStore) permsMgr := permissions.NewManager(testStore)
usersMgr := users.NewManager(testStore)
settingsMgr := settings.NewManager(testStore, usersMgr, &noopExtraSettings{}, permsMgr, settings.IdpConfig{})
var storedEvents []activity.Activity
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
storedEvents = append(storedEvents, activityID.(activity.Activity))
},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
}, },
} }
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil) require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close) t.Cleanup(proxySrv.Close)
mgr := &managerImpl{ proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: testStore, store: testStore,
accountManager: accountMgr, accountManager: accountMgr,
permissionsManager: permsMgr, permissionsManager: permsMgr,
settingsManager: settingsMgr, proxyController: proxyController,
proxyGRPCServer: proxySrv,
clusterDeriver: &testClusterDeriver{ clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"}, domains: []string{"test.netbird.io"},
}, },
} }
mgr.exposeTracker = &exposeTracker{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr, testStore return mgr, testStore
} }
@@ -788,7 +789,7 @@ func Test_validateExposePermission(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error")) mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &managerImpl{store: mockStore} mgr := &Manager{store: mockStore}
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings") assert.Contains(t, err.Error(), "get account settings")
@@ -801,7 +802,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with random domain", func(t *testing.T) { t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t) mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
} }
@@ -816,7 +817,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain) assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral") assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set") assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set") assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
}) })
@@ -824,7 +825,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with custom domain", func(t *testing.T) { t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t) mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 80, Port: 80,
Protocol: "http", Protocol: "http",
Domain: "example.com", Domain: "example.com",
@@ -845,7 +846,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s) err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err) require.NoError(t, err)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
} }
@@ -858,7 +859,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("validates request fields", func(t *testing.T) { t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t) mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 0, Port: 0,
Protocol: "http", Protocol: "http",
} }
@@ -872,67 +873,67 @@ func TestCreateServiceFromPeer(t *testing.T) {
func TestExposeServiceRequestValidate(t *testing.T) { func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
req reverseproxy.ExposeServiceRequest req rpservice.ExposeServiceRequest
wantErr string wantErr string
}{ }{
{ {
name: "valid http request", name: "valid http request",
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"}, req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "", wantErr: "",
}, },
{ {
name: "valid https request with pin", name: "valid https request with pin",
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "", wantErr: "",
}, },
{ {
name: "port zero rejected", name: "port zero rejected",
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"}, req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535", wantErr: "port must be between 1 and 65535",
}, },
{ {
name: "negative port rejected", name: "negative port rejected",
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"}, req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535", wantErr: "port must be between 1 and 65535",
}, },
{ {
name: "port above 65535 rejected", name: "port above 65535 rejected",
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"}, req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535", wantErr: "port must be between 1 and 65535",
}, },
{ {
name: "unsupported protocol", name: "unsupported protocol",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol", wantErr: "unsupported protocol",
}, },
{ {
name: "invalid pin format", name: "invalid pin format",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin", wantErr: "invalid pin",
}, },
{ {
name: "pin too short", name: "pin too short",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin", wantErr: "invalid pin",
}, },
{ {
name: "valid 6-digit pin", name: "valid 6-digit pin",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "", wantErr: "",
}, },
{ {
name: "empty user group name", name: "empty user group name",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty", wantErr: "user group name cannot be empty",
}, },
{ {
name: "invalid name prefix", name: "invalid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix", wantErr: "invalid name prefix",
}, },
{ {
name: "valid name prefix", name: "valid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "", wantErr: "",
}, },
} }
@@ -950,7 +951,7 @@ func TestExposeServiceRequestValidate(t *testing.T) {
} }
t.Run("nil receiver", func(t *testing.T) { t.Run("nil receiver", func(t *testing.T) {
var req *reverseproxy.ExposeServiceRequest var req *rpservice.ExposeServiceRequest
err := req.Validate() err := req.Validate()
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil") assert.Contains(t, err.Error(), "request cannot be nil")
@@ -964,7 +965,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
mgr, testStore := setupIntegrationTest(t) mgr, testStore := setupIntegrationTest(t)
// First create a service // First create a service
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
} }
@@ -983,7 +984,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
t.Run("expire uses correct activity", func(t *testing.T) { t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t) mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
} }
@@ -1001,7 +1002,7 @@ func TestStopServiceFromPeer(t *testing.T) {
t.Run("stops service by domain", func(t *testing.T) { t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t) mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{ req := &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
} }
@@ -1016,53 +1017,59 @@ func TestStopServiceFromPeer(t *testing.T) {
}) })
} }
func TestDeleteService_UntracksEphemeralExpose(t *testing.T) { func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
ctx := context.Background() ctx := context.Background()
mgr, _ := setupIntegrationTest(t) mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create")
// Look up the service by domain to get its store ID count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain) require.NoError(t, err)
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err) require.NoError(t, err)
// Delete via the API path (user-initiated)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete") count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
// A new expose should succeed (not blocked by stale tracking) _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 9090, Port: 9090,
Protocol: "http", Protocol: "http",
}) })
assert.NoError(t, err, "new expose should succeed after API delete cleared tracking") assert.NoError(t, err, "new expose should succeed after API delete")
} }
func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) { func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
ctx := context.Background() ctx := context.Background()
mgr, _ := setupIntegrationTest(t) mgr, _ := setupIntegrationTest(t)
for i := range 3 { for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i, Port: 8080 + i,
Protocol: "http", Protocol: "http",
}) })
require.NoError(t, err) require.NoError(t, err)
} }
assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked") count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(3), count, "all ephemeral services should exist")
err := mgr.DeleteAllServices(ctx, testAccountID, testUserID) err = mgr.DeleteAllServices(ctx, testAccountID, testUserID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices") count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices")
} }
func TestRenewServiceFromPeer(t *testing.T) { func TestRenewServiceFromPeer(t *testing.T) {
@@ -1071,7 +1078,7 @@ func TestRenewServiceFromPeer(t *testing.T) {
t.Run("renews tracked expose", func(t *testing.T) { t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t) mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080, Port: 8080,
Protocol: "http", Protocol: "http",
}) })
@@ -1112,3 +1119,74 @@ func TestGetGroupIDsFromNames(t *testing.T) {
assert.Contains(t, err.Error(), "no group names provided") assert.Contains(t, err.Error(), "no group names provided")
}) })
} }
func TestDeleteService_DeletesTargets(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
userID := "test-user"
sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyController: proxyController,
}
service := &rpservice.Service{
ID: "service-1",
AccountID: accountID,
Domain: "test.example.com",
ProxyCluster: "cluster1",
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"},
},
}
err = sqlStore.CreateService(ctx, service)
require.NoError(t, err)
retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().
UpdateAccountPeers(ctx, accountID)
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
require.NoError(t, err)
_, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.Error(t, err)
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}

View File

@@ -1,4 +1,4 @@
package reverseproxy package service
import ( import (
"crypto/rand" "crypto/rand"
@@ -14,6 +14,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/util/crypt"
@@ -29,15 +30,15 @@ const (
Delete Operation = "delete" Delete Operation = "delete"
) )
type ProxyStatus string type Status string
const ( const (
StatusPending ProxyStatus = "pending" StatusPending Status = "pending"
StatusActive ProxyStatus = "active" StatusActive Status = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created" StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending" StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed" StatusCertificateFailed Status = "certificate_failed"
StatusError ProxyStatus = "error" StatusError Status = "error"
TargetTypePeer = "peer" TargetTypePeer = "peer"
TargetTypeHost = "host" TargetTypeHost = "host"
@@ -111,14 +112,7 @@ func (a *AuthConfig) ClearSecrets() {
} }
} }
type OIDCValidationConfig struct { type Meta struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time CreatedAt time.Time
CertificateIssuedAt *time.Time CertificateIssuedAt *time.Time
Status string Status string
@@ -135,12 +129,12 @@ type Service struct {
Enabled bool Enabled bool
PassHostHeader bool PassHostHeader bool
RewriteRedirects bool RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"` Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"` Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"` SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"` SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent'"` Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string SourcePeer string `gorm:"index:idx_service_source_peer"`
} }
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
@@ -165,7 +159,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target,
// only be called during initial creation, not for updates. // only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() { func (s *Service) InitNewRecord() {
s.ID = xid.New().String() s.ID = xid.New().String()
s.Meta = ServiceMeta{ s.Meta = Meta{
CreatedAt: time.Now(), CreatedAt: time.Now(),
Status: string(StatusPending), Status: string(StatusPending),
} }
@@ -239,7 +233,7 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp return resp
} }
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping { func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets { for _, target := range s.Targets {
if !target.Enabled { if !target.Enabled {

View File

@@ -1,4 +1,4 @@
package reverseproxy package service
import ( import (
"errors" "errors"
@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -109,7 +110,7 @@ func TestIsDefaultPort(t *testing.T) {
} }
func TestToProtoMapping_PortInTargetURL(t *testing.T) { func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{} oidcConfig := proxy.OIDCValidationConfig{}
tests := []struct { tests := []struct {
name string name string
@@ -202,7 +203,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true}, {TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
}, },
} }
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{}) pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1) require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target) assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
} }
@@ -219,7 +220,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) { t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{}) pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type) assert.Equal(t, tt.want, pm.Type)
}) })
} }

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" { if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil { if err != nil {
log.Fatalf("failed to create certificate manager: %v", err) log.Fatalf("failed to create certificate service: %v", err)
} }
transportCredentials := credentials.NewTLS(certManager.TLSConfig()) transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -152,10 +152,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil { if err != nil {
log.Fatalf("failed to create management server: %v", err) log.Fatalf("failed to create management server: %v", err)
} }
reverseProxyMgr := s.ReverseProxyManager() serviceMgr := s.ServiceManager()
srv.SetReverseProxyManager(reverseProxyMgr) srv.SetReverseProxyManager(serviceMgr)
if reverseProxyMgr != nil { if serviceMgr != nil {
reverseProxyMgr.StartExposeReaper(context.Background()) serviceMgr.StartExposeReaper(context.Background())
} }
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
@@ -168,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
}) })
return proxyService return proxyService
}) })
@@ -193,7 +194,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication") log.Info("One-time token store initialized for proxy authentication")
return tokenStore return tokenStore
}) })

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
}) })
} }
func (s *BaseServer) ServiceProxyController() proxy.Controller {
return Create(s, func() proxy.Controller {
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create service proxy controller: %v", err)
}
return controller
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store()) return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager { return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil { if err != nil {
log.Fatalf("failed to create account manager: %v", err) log.Fatalf("failed to create account service: %v", err)
} }
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager()) accountManager.SetServiceManager(s.ServiceManager())
}) })
return accountManager return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager { return Create(s, func() idp.Manager {
var idpManager idp.Manager var idpManager idp.Manager
var err error var err error
// Use embedded IdP manager if embedded Dex is configured and enabled. // Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured. // Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err) log.Fatalf("failed to create embedded IDP service: %v", err)
} }
return idpManager return idpManager
} }
// Fall back to external IdP manager // Fall back to external IdP service
if s.Config.IdpManagerConfig != nil { if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create IDP manager: %v", err) log.Fatalf("failed to create IDP service: %v", err)
} }
} }
return idpManager return idpManager
}) })
} }
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil // OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager()) return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
}) })
} }
@@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager {
}) })
} }
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() reverseproxy.Manager { return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.SettingsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create proxy manager: %v", err)
}
return manager
}) })
} }
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
return &m return &m
}) })
} }

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered // Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks // before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager). // registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer() s.GRPCServer()
for _, fn := range s.afterInit { for _, fn := range s.afterInit {

View File

@@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -39,7 +39,7 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
} }
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{ created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix, NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port), Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol), Protocol: exposeProtocolToString(exposeReq.Protocol),
@@ -167,14 +167,14 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key
return accountID, peer, nil return accountID, peer, nil
} }
func (s *Server) getReverseProxyManager() reverseproxy.Manager { func (s *Server) getReverseProxyManager() rpservice.Manager {
s.reverseProxyMu.RLock() s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock() defer s.reverseProxyMu.RUnlock()
return s.reverseProxyManager return s.reverseProxyManager
} }
// SetReverseProxyManager sets the reverse proxy manager on the server. // SetReverseProxyManager sets the reverse proxy manager on the server.
func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) { func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyMu.Lock() s.reverseProxyMu.Lock()
defer s.reverseProxyMu.Unlock() defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr s.reverseProxyManager = mgr

View File

@@ -1,28 +1,23 @@
package grpc package grpc
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
) )
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct { type tokenMetadata struct {
ServiceID string ServiceID string
AccountID string AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time CreatedAt time.Time
} }
// NewOneTimeTokenStore creates a new token store with automatic cleanup // OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// of expired tokens. The cleanupInterval determines how often expired // Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
// tokens are removed from memory. type OneTimeTokenStore struct {
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { cache *cache.Cache[string]
store := &OneTimeTokenStore{ ctx context.Context
tokens: make(map[string]*tokenMetadata), }
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}), // NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
} }
// Start background cleanup goroutine return &OneTimeTokenStore{
go store.cleanupExpired() cache: cache.New[string](cacheStore),
ctx: ctx,
return store }, nil
} }
// GenerateToken creates a new cryptographically secure one-time token // GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// //
// Returns the generated token string or an error if random generation fails. // Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32) randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil { if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err) return "", fmt.Errorf("failed to generate random token: %w", err)
} }
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes) token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock() metadata := &tokenMetadata{
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ServiceID: serviceID, ServiceID: serviceID,
AccountID: accountID, AccountID: accountID,
ExpiresAt: time.Now().Add(ttl), ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl) serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match // - Account ID doesn't match
// - Reverse proxy ID doesn't match // - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock() hashedToken := hashToken(token)
defer s.mu.Unlock()
metadata, exists := s.tokens[token] metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if !exists { if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
serviceID, accountID)
return fmt.Errorf("invalid token") return fmt.Errorf("invalid token")
} }
// Check expiration metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) { if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token) log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
return fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch") return fmt.Errorf("account ID mismatch")
} }
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch") return fmt.Errorf("service ID mismatch")
} }
// Delete token immediately to enforce single-use if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
delete(s.tokens, token) log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s", log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
serviceID, accountID)
return nil return nil
} }
// cleanupExpired removes expired tokens in the background to prevent memory leaks func hashToken(token string) string {
func (s *OneTimeTokenStore) cleanupExpired() { hash := sha256.Sum256([]byte(token))
for { return hex.EncodeToString(hash[:])
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
} }

View File

@@ -24,8 +24,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
@@ -58,14 +59,17 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection // Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Manager for access logs // Manager for access logs
accessLogManager accesslogs.Manager accessLogManager accesslogs.Manager
// Manager for reverse proxy operations // Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers // Manager for peers
peersManager peers.Manager peersManager peers.Manager
@@ -104,7 +108,7 @@ type proxyConnection struct {
} }
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{ s := &ProxyServiceServer{
accessLogManager: accessLogMgr, accessLogManager: accessLogMgr,
@@ -112,9 +116,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
tokenStore: tokenStore, tokenStore: tokenStore,
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel, pkceCleanupCancel: cancel,
} }
go s.cleanupPKCEVerifiers(ctx) go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s return s
} }
@@ -138,13 +144,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
} }
} }
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
}
}
// Close stops background goroutines. // Close stops background goroutines.
func (s *ProxyServiceServer) Close() { func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel() s.pkceCleanupCancel()
} }
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) { func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.reverseProxyManager = manager s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
s.proxyController = proxyController
} }
// GetMappingUpdate handles the control stream with proxy clients // GetMappingUpdate handles the control stream with proxy clients
@@ -179,7 +205,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID) if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"address": proxyAddress, "address": proxyAddress,
@@ -187,8 +221,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID) s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID) if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel() cancel()
log.Infof("Proxy %s disconnected", proxyID) log.Infof("Proxy %s disconnected", proxyID)
}() }()
@@ -200,6 +241,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2) errChan := make(chan error, 2)
go s.sender(conn, errChan) go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select { select {
case err := <-errChan: case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -208,10 +252,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy. // sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent. // Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return fmt.Errorf("get services from store: %w", err) return fmt.Errorf("get services from store: %w", err)
} }
@@ -220,7 +281,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid") return fmt.Errorf("proxy address is invalid")
} }
var filtered []*reverseproxy.Service var filtered []*rpservice.Service
for _, service := range services { for _, service := range services {
if !service.Enabled { if !service.Enabled {
continue continue
@@ -255,7 +316,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{ Mapping: []*proto.ProxyMapping{
service.ToProtoMapping( service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token, token,
s.GetOIDCValidationConfig(), s.GetOIDCValidationConfig(),
), ),
@@ -389,61 +450,47 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls return urls
} }
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per // For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management. // proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) { func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
updateResponse := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{update},
}
if clusterAddr == "" { if clusterAddr == "" {
s.SendServiceUpdate(update) s.SendServiceUpdate(updateResponse)
return return
} }
proxySet, ok := s.clusterProxies.Load(clusterAddr) if s.proxyController == nil {
if !ok { log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
log.Debugf("No proxies connected for cluster %s", clusterAddr) return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return return
} }
log.Debugf("Sending service update to cluster %s", clusterAddr) log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { for _, proxyID := range proxyIDs {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok { if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection) conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID) msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil { if msg == nil {
return true continue
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr) log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default: default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
} }
} }
return true }
})
} }
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
@@ -490,35 +537,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
} }
} }
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err) log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -537,7 +557,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil }, nil
} }
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) { switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin: case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -548,7 +568,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
} }
} }
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -562,7 +582,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN return true, "pin-user", proxyauth.MethodPIN
} }
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -584,7 +604,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
} }
} }
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" { if !authenticated || service.SessionPrivateKey == "" {
return "", nil return "", nil
} }
@@ -624,7 +644,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
if certificateIssued { if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err) return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
} }
@@ -636,7 +656,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus) internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status") log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err) return nil, status.Errorf(codes.Internal, "update service status: %v", err)
} }
@@ -651,22 +671,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
// protoStatusToInternal maps proto status to internal status // protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus { func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus { switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING: case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE: case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED: case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR: case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError return rpservice.StatusError
default: default:
return reverseproxy.StatusError return rpservice.StatusError
} }
} }
@@ -731,7 +751,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
} }
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId()) services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err) log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -794,8 +814,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation // GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping. // in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig { func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{ return proxy.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer, Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience}, Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation, KeysLocation: s.oidcConfig.KeysLocation,
@@ -854,12 +874,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key // Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("get services: %w", err) return "", fmt.Errorf("get services: %w", err)
} }
var service *reverseproxy.Service var service *rpservice.Service
for _, svc := range services { for _, svc := range services {
if svc.Domain == domain { if svc.Domain == domain {
service = svc service = svc
@@ -925,8 +945,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
} }
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID) services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account services: %w", err) return nil, fmt.Errorf("get account services: %w", err)
} }
@@ -1047,8 +1067,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("get services: %w", err) return nil, fmt.Errorf("get services: %w", err)
} }
@@ -1062,7 +1082,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain) return nil, fmt.Errorf("service not found for domain: %s", domain)
} }
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error { func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil return nil
} }

View File

@@ -8,12 +8,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
type mockReverseProxyManager struct { type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
} }
@@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account
return nil return nil
} }
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil { if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.proxiesByAccount[accountID], nil return m.proxiesByAccount[accountID], nil
} }
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*reverseproxy.Service{}, nil return []*service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error { func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil return nil
} }
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil return nil
} }
@@ -68,16 +68,16 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil return nil
} }
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil return "", nil
} }
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &reverseproxy.ExposeServiceResponse{}, nil return &service.ExposeServiceResponse{}, nil
} }
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
@@ -111,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string name string
domain string domain string
userID string userID string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
users map[string]*types.User users map[string]*types.User
proxyErr error proxyErr error
userErr error userErr error
@@ -122,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found", name: "user not found",
domain: "app.example.com", domain: "app.example.com",
userID: "unknown-user", userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
users: map[string]*types.User{}, users: map[string]*types.User{},
@@ -133,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account", name: "proxy not found in user's account",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
}, },
@@ -144,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible", name: "proxy exists in different account - not accessible",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}}, "account2": {{Domain: "app.example.com", AccountID: "account2"}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
@@ -157,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access", name: "no bearer auth configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
@@ -169,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access", name: "bearer auth disabled - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, BearerAuth: &service.BearerAuthConfig{Enabled: false},
}, },
}}, }},
}, },
@@ -187,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access", name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{}, DistributionGroups: []string{},
}, },
@@ -208,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups", name: "user not in allowed groups",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -230,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access", name: "user in one of the allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -251,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access", name: "user in all allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -284,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one", name: "multiple proxies in account - finds correct one",
domain: "app2.example.com", domain: "app2.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "app1.example.com", AccountID: "account1"}, {Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, {Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"}, {Domain: "app3.example.com", AccountID: "account1"},
}, },
}, },
@@ -301,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{ server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{ serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount, proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr, err: tt.proxyErr,
}, },
@@ -328,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string name string
accountID string accountID string
domain string domain string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
expectProxy bool expectProxy bool
expectErr bool expectErr bool
@@ -337,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found", name: "proxy found",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "other.example.com", AccountID: "account1"}, {Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"}, {Domain: "app.example.com", AccountID: "account1"},
@@ -350,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account", name: "proxy not found in account",
accountID: "account1", accountID: "account1",
domain: "unknown.example.com", domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
expectProxy: false, expectProxy: false,
@@ -360,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account", name: "empty proxy list for account",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
expectProxy: false, expectProxy: false,
expectErr: true, expectErr: true,
}, },
@@ -378,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{ server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{ serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount, proxiesByAccount: tt.proxiesByAccount,
err: tt.err, err: tt.err,
}, },

View File

@@ -1,19 +1,73 @@
package grpc package grpc
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"sync"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
}
func newTestProxyController() *testProxyController {
return &testProxyController{
clusterProxies: make(map[string]map[string]struct{}),
}
}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
}
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clusterProxies[clusterAddr]; !ok {
c.clusterProxies[clusterAddr] = make(map[string]struct{})
}
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
delete(proxies, proxyID)
}
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
proxies, ok := c.clusterProxies[clusterAddr]
if !ok {
return nil
}
result := make([]string, 0, len(proxies))
for id := range proxies {
result = append(result, id)
}
return result
}
// registerFakeProxy adds a fake proxy connection to the server's internal maps // registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received. // and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
@@ -25,8 +79,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) _ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch return ch
} }
@@ -41,12 +94,13 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
} }
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
defer tokenStore.Close() require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
} }
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com" const cluster = "proxy.example.com"
const numProxies = 3 const numProxies = 3
@@ -67,11 +121,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}, },
} }
update := &proto.GetMappingUpdateResponse{ s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies) tokens := make([]string, numProxies)
for i, ch := range channels { for i, ch := range channels {
@@ -101,12 +151,13 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
} }
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
defer tokenStore.Close() require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
} }
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com" const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster) ch1 := registerFakeProxy(s, "proxy-a", cluster)
@@ -119,11 +170,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
Domain: "test.example.com", Domain: "test.example.com",
} }
update := &proto.GetMappingUpdateResponse{ s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdateToCluster(update, cluster)
resp1 := drainChannel(ch1) resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2) resp2 := drainChannel(ch2)
@@ -135,18 +182,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
// Delete operations should not generate tokens // Delete operations should not generate tokens
assert.Empty(t, resp1.Mapping[0].AuthToken) assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken) assert.Empty(t, resp2.Mapping[0].AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
} }
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
defer tokenStore.Close() require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
} }
s.SetProxyController(newTestProxyController())
// Register proxies in different clusters (SendServiceUpdate broadcasts to all) // Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")

View File

@@ -26,7 +26,7 @@ import (
"github.com/netbirdio/netbird/shared/management/client/common" "github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
@@ -82,7 +82,7 @@ type Server struct {
syncLimEnabled bool syncLimEnabled bool
syncLim int32 syncLim int32
reverseProxyManager reverseproxy.Manager reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex reverseProxyMu sync.RWMutex
} }

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -34,11 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err) require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore} serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
proxyService.SetProxyManager(proxyManager) require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)
@@ -54,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t) pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{ testProxy := &service.Service{
ID: "testProxyId", ID: "testProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Test Proxy", Name: "Test Proxy",
@@ -62,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
}, },
}, },
} }
require.NoError(t, testStore.CreateService(ctx, testProxy)) require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{ restrictedProxy := &service.Service{
ID: "restrictedProxyId", ID: "restrictedProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Restricted Proxy", Name: "Restricted Proxy",
@@ -78,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"allowedGroupId"}, DistributionGroups: []string{"allowedGroupId"},
}, },
@@ -239,79 +243,101 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing") assert.Contains(t, resp.DeniedReason, "missing")
} }
type testValidateSessionProxyManager struct { type testValidateSessionServiceManager struct {
store store.Store store store.Store
} }
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone) return m.store.GetServices(ctx, store.LockingStrengthNone)
} }
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
} }
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil return "", nil
} }
func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {} func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct { type testValidateSessionUsersManager struct {
store store.Store store store.Store

View File

@@ -15,7 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
settingsManager settings.Manager settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
// config contains the management server configuration // config contains the management server configuration
config *nbconfig.Config config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil) var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.reverseProxyManager = serviceManager am.serviceManager = serviceManager
} }
func isUniqueConstraintError(err error) bool { func isUniqueConstraintError(err error) bool {
@@ -395,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
} }
if reloadReverseProxy { if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
} }
} }
@@ -730,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
} }
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID) err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err) return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
} }

View File

@@ -1,12 +1,14 @@
package account package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import ( import (
"context" "context"
"net" "net"
"net/netip" "net/netip"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -61,11 +63,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -140,5 +142,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager) SetServiceManager(serviceManager service.Manager)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -19,6 +19,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -27,8 +28,10 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -1803,12 +1806,12 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24", Address: "172.12.6.1/24",
}, },
}, },
Services: []*reverseproxy.Service{ Services: []*service.Service{
{ {
ID: "service1", ID: "service1",
Name: "test-service", Name: "test-service",
AccountID: "account1", AccountID: "account1",
Targets: []*reverseproxy.Target{}, Targets: []*service.Target{},
}, },
}, },
NetworkMapCache: &types.NetworkMapBuilder{}, NetworkMapCache: &types.NetworkMapBuilder{},
@@ -3113,6 +3116,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
CleanupStale(gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
ctx := context.Background() ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -3123,8 +3132,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err return nil, nil, err
} }
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil) proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, settingsMockManager, proxyGrpcServer, nil)) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
return manager, updateManager, nil return manager, updateManager, nil
} }

View File

@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
switch storeEngine { switch storeEngine {
case types.SqliteStoreEngine: case types.SqliteStoreEngine:
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB)) dbFile := eventSinkDB
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
dbFile = envFile
}
connStr := dbFile
if !filepath.IsAbs(dbFile) {
connStr = filepath.Join(dataDir, dbFile)
}
dialector = sqlite.Open(connStr)
case types.PostgresStoreEngine: case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv) dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok { if !ok {

View File

@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string var groupIDsToDelete []string
var deletedGroups []*types.Group var deletedGroups []*types.Group
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
continue continue
} }
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
} }
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil return nil
} }
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"network resource", group.Resources[0].ID} return &GroupLinkError{"network resource", group.Resources[0].ID}
} }
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }

View File

@@ -12,6 +12,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -26,6 +27,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
} }
} }
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctrl := gomock.NewController(t)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
AnyTimes()
settingsMock.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
am.settingsManager = settingsMock
_, account, err := initTestGroupAccount(am)
require.NoError(t, err)
grp := &types.Group{
ID: "grp-for-flow",
AccountID: account.Id,
Name: "Group for flow",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
require.Error(t, err)
var gErr *GroupLinkError
require.ErrorAs(t, err, &gErr)
assert.Equal(t, "settings", gErr.Resource)
assert.Equal(t, "traffic event logging", gErr.Name)
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
regularGrp := &types.Group{
ID: "grp-regular",
AccountID: account.Id,
Name: "Regular group",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
require.NoError(t, err)
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
require.Error(t, err)
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
assert.Error(t, err)
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store) permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager) resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp" idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
) )
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router) idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router) instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil { if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
} }
// Register OAuth callback handler for proxy authentication // Register OAuth callback handler for proxy authentication

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer() oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)
@@ -208,9 +209,10 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
nil,
) )
proxyService.SetProxyManager(&testServiceManager{store: testStore}) proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil) handler := NewAuthCallbackHandler(proxyService, nil)
@@ -239,12 +241,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub) pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv) privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{ testProxy := &service.Service{
ID: "testProxyId", ID: "testProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Test Proxy", Name: "Test Proxy",
Domain: "test-proxy.example.com", Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -254,8 +256,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"allowedGroupId"}, DistributionGroups: []string{"allowedGroupId"},
}, },
@@ -265,12 +267,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
} }
require.NoError(t, testStore.CreateService(ctx, testProxy)) require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{ restrictedProxy := &service.Service{
ID: "restrictedProxyId", ID: "restrictedProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Restricted Proxy", Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com", Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -280,8 +282,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"restrictedGroupId"}, DistributionGroups: []string{"restrictedGroupId"},
}, },
@@ -291,12 +293,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
} }
require.NoError(t, testStore.CreateService(ctx, restrictedProxy)) require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{ noAuthProxy := &service.Service{
ID: "noAuthProxyId", ID: "noAuthProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "No Auth Proxy", Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com", Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -306,8 +308,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: false, Enabled: false,
}, },
}, },
@@ -361,19 +363,19 @@ func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, u
return nil return nil
} }
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
@@ -385,7 +387,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
return nil return nil
} }
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil return nil
} }
@@ -397,15 +399,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
return nil return nil
} }
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone) return m.store.GetServices(ctx, store.LockingStrengthNone)
} }
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
} }
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
@@ -413,7 +415,7 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil return "", nil
} }
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil return nil, nil
} }

View File

@@ -9,10 +9,13 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +94,24 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
} }
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) if err != nil {
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) t.Fatalf("Failed to create proxy token store: %v", err)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, settingsManager, proxyServiceServer, domainManager) }
proxyServiceServer.SetProxyManager(reverseProxyManager) noopMeter := noop.NewMeterProvider().Meter("")
am.SetServiceManager(reverseProxyManager) proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked // @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +129,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }

View File

@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
// EmbeddedStorageConfig holds storage configuration for the embedded IdP. // EmbeddedStorageConfig holds storage configuration for the embedded IdP.
type EmbeddedStorageConfig struct { type EmbeddedStorageConfig struct {
// Type is the storage type (currently only "sqlite3" is supported) // Type is the storage type: "sqlite3" (default) or "postgres"
Type string Type string
// Config contains type-specific configuration // Config contains type-specific configuration
Config EmbeddedStorageTypeConfig Config EmbeddedStorageTypeConfig
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
type EmbeddedStorageTypeConfig struct { type EmbeddedStorageTypeConfig struct {
// File is the path to the SQLite database file (for sqlite3 type) // File is the path to the SQLite database file (for sqlite3 type)
File string File string
// DSN is the connection string for postgres
DSN string
} }
// OwnerConfig represents the initial owner/admin user for the embedded IdP. // OwnerConfig represents the initial owner/admin user for the embedded IdP.
@@ -74,6 +76,22 @@ type OwnerConfig struct {
Username string Username string
} }
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
switch storageType {
case "sqlite3":
return map[string]interface{}{
"file": cfg.File,
}, nil
case "postgres":
return map[string]interface{}{
"dsn": cfg.DSN,
}, nil
default:
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
}
}
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig. // ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Issuer == "" { if c.Issuer == "" {
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" { if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
return nil, fmt.Errorf("storage file is required for sqlite3") return nil, fmt.Errorf("storage file is required for sqlite3")
} }
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
return nil, fmt.Errorf("storage DSN is required for postgres")
}
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
if err != nil {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback (both relative and absolute) // Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs cliRedirectURIs := c.CLIRedirectURIs
@@ -100,10 +126,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
cfg := &dex.YAMLConfig{ cfg := &dex.YAMLConfig{
Issuer: c.Issuer, Issuer: c.Issuer,
Storage: dex.Storage{ Storage: dex.Storage{
Type: c.Storage.Type, Type: c.Storage.Type,
Config: map[string]interface{}{ Config: storageConfig,
"file": c.Storage.Config.File,
},
}, },
Web: dex.Web{ Web: dex.Web{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},

View File

@@ -14,7 +14,7 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -294,9 +294,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
localUsers++ localUsers++
} else { } else {
idpUsers++ idpUsers++
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
} }
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
} }
} }
} }
@@ -358,12 +358,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} }
servicesTargets += len(service.Targets) servicesTargets += len(service.Targets)
switch reverseproxy.ProxyStatus(service.Meta.Status) { switch rpservice.Status(service.Meta.Status) {
case reverseproxy.StatusActive: case rpservice.StatusActive:
servicesStatusActive++ servicesStatusActive++
case reverseproxy.StatusPending: case rpservice.StatusPending:
servicesStatusPending++ servicesStatusPending++
case reverseproxy.StatusError, reverseproxy.StatusCertificateFailed, reverseproxy.StatusTunnelNotCreated: case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
servicesStatusError++ servicesStatusError++
} }
@@ -531,6 +531,9 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz"). // Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found. // Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string { func extractIdpType(connectorID string) string {
if connectorID == "local" {
return "local"
}
idx := strings.LastIndex(connectorID, "-") idx := strings.LastIndex(connectorID, "-")
if idx <= 0 { if idx <= 0 {
return "oidc" return "oidc"

View File

@@ -6,7 +6,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -29,6 +29,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local") localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0") idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
return []*types.Account{ return []*types.Account{
{ {
Id: "1", Id: "1",
@@ -116,29 +117,29 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
}, },
}, },
Services: []*reverseproxy.Service{ Services: []*rpservice.Service{
{ {
ID: "svc1", ID: "svc1",
Enabled: true, Enabled: true,
Targets: []*reverseproxy.Target{ Targets: []*rpservice.Target{
{TargetType: "peer"}, {TargetType: "peer"},
{TargetType: "host"}, {TargetType: "host"},
}, },
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{Enabled: true}, PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
}, },
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusActive)}, Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
}, },
{ {
ID: "svc2", ID: "svc2",
Enabled: false, Enabled: false,
Targets: []*reverseproxy.Target{ Targets: []*rpservice.Target{
{TargetType: "domain"}, {TargetType: "domain"},
}, },
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: true}, BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
}, },
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusPending)}, Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
}, },
}, },
}, },
@@ -206,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {}, "1": {},
}, },
}, },
oidcUserID: {
Id: oidcUserID,
IsServiceUser: false,
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
}, },
Networks: []*networkTypes.Network{ Networks: []*networkTypes.Network{
{ {
@@ -278,14 +286,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["rules"] != 4 { if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"]) t.Errorf("expected 4 rules, got %d", properties["rules"])
} }
if properties["users"] != 2 { if properties["users"] != 3 {
t.Errorf("expected 1 users, got %d", properties["users"]) t.Errorf("expected 3 users, got %d", properties["users"])
} }
if properties["setup_keys_usage"] != 2 { if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"]) t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
} }
if properties["pats"] != 4 { if properties["pats"] != 5 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"]) t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
} }
if properties["peers_ssh_enabled"] != 2 { if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"]) t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
@@ -369,14 +377,20 @@ func TestGenerateProperties(t *testing.T) {
if properties["local_users_count"] != 1 { if properties["local_users_count"] != 1 {
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"]) t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
} }
if properties["idp_users_count"] != 1 { if properties["idp_users_count"] != 2 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"]) t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_local"] != 1 {
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
} }
if properties["embedded_idp_users_zitadel"] != 1 { if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"]) t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
} }
if properties["embedded_idp_count"] != 1 { if properties["embedded_idp_users_oidc"] != 1 {
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"]) t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
}
if properties["embedded_idp_count"] != 3 {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
} }
if properties["services"] != 2 { if properties["services"] != 2 {
@@ -436,7 +450,8 @@ func TestExtractIdpType(t *testing.T) {
{"microsoft-abc123", "microsoft"}, {"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"}, {"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"}, {"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "oidc"}, {"local", "local"},
{"d6jvvp69kmnc73c9pl40", "oidc"},
{"", "oidc"}, {"", "oidc"},
} }

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
} }
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op // Mock implementation - no-op
} }

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +33,23 @@ type Manager interface {
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager permissionsManager permissions.Manager
groupsManager groups.Manager groupsManager groups.Manager
accountManager account.Manager accountManager account.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
} }
type mockManager struct { type mockManager struct {
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
groupsManager: groupsManager, groupsManager: groupsManager,
accountManager: accountManager, accountManager: accountManager,
reverseProxyManager: reverseproxyManager, serviceManager: reverseproxyManager,
} }
} }
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them // TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() { go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID) err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err) log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
} }
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID) serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err) return fmt.Errorf("failed to check if resource is used by service: %w", err)
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err) require.NoError(t, err)
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err) require.Error(t, err)
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err) require.NoError(t, err)
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err) require.Error(t, err)
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)

View File

@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings var settings *types.Settings
var eventsToStore []func() var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID) serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err) return fmt.Errorf("failed to check if resource is used by service: %w", err)
} }

View File

@@ -352,9 +352,10 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest)
p.Name = a.Name p.Name = a.Name
p.Key = a.WgPubKey p.Key = a.WgPubKey
p.Meta = PeerSystemMeta{ p.Meta = PeerSystemMeta{
Hostname: a.Name, Hostname: a.Name,
GoOS: "js", GoOS: "js",
OS: "js", OS: "js",
KernelVersion: "wasm",
} }
} }

View File

@@ -28,9 +28,10 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{}, &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &accesslogs.AccessLogEntry{}, &proxy.Proxy{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err) return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -2075,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil return checks, nil
} }
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key pass_host_header, rewrite_redirects, session_private_key, session_public_key
@@ -2090,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err return nil, err
} }
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) { services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s reverseproxy.Service var s rpservice.Service
var auth []byte var auth []byte
var createdAt, certIssuedAt sql.NullTime var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
@@ -2121,7 +2122,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
} }
} }
s.Meta = reverseproxy.ServiceMeta{} s.Meta = rpservice.Meta{}
if createdAt.Valid { if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time s.Meta.CreatedAt = createdAt.Time
} }
@@ -2142,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
s.SessionPublicKey = sessionPublicKey.String s.SessionPublicKey = sessionPublicKey.String
} }
s.Targets = []*reverseproxy.Target{} s.Targets = []*rpservice.Target{}
return &s, nil return &s, nil
}) })
if err != nil { if err != nil {
@@ -2154,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
} }
serviceIDs := make([]string, len(services)) serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service) serviceMap := make(map[string]*rpservice.Service)
for i, s := range services { for i, s := range services {
serviceIDs[i] = s.ID serviceIDs[i] = s.ID
serviceMap[s.ID] = s serviceMap[s.ID] = s
@@ -2165,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err return nil, err
} }
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) { targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
var t reverseproxy.Target var t rpservice.Target
var path sql.NullString var path sql.NullString
err := row.Scan( err := row.Scan(
&t.ID, &t.ID,
@@ -2728,14 +2729,28 @@ func (s *SqlStore) GetStoreEngine() types.Engine {
// NewSqliteStore creates a new SQLite store. // NewSqliteStore creates a new SQLite store.
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) storeFile := storeSqliteFileName
if runtime.GOOS == "windows" { if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows storeFile = envFile
storeStr = storeSqliteFileName
} }
file := filepath.Join(dataDir, storeStr) // Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
db, err := gorm.Open(sqlite.Open(file), getGormConfig()) filePath, query, hasQuery := strings.Cut(storeFile, "?")
connStr := filePath
if !filepath.IsAbs(filePath) {
connStr = filepath.Join(dataDir, filePath)
}
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
if hasQuery {
connStr += "?" + query
} else if runtime.GOOS != "windows" {
// To avoid `The process cannot access the file because it is being used by another process` on Windows
connStr += "?cache=shared"
}
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -4838,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil return peerID, nil
} }
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy() serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err) return fmt.Errorf("encrypt service data: %w", err)
@@ -4852,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
return nil return nil
} }
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy() serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err) return fmt.Errorf("encrypt service data: %w", err)
} }
// Create target type instance outside transaction to avoid variable shadowing
targetType := &rpservice.Target{}
// Use a transaction to ensure atomic updates of the service and its targets // Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets // Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil { if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
return err return err
} }
@@ -4882,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
} }
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error { func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID) result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store") return status.Errorf(status.Internal, "failed to delete service from store")
@@ -4895,13 +4913,53 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil return nil
} }
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete target from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
}
return nil
}
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete targets from store")
}
return nil
}
// GetTargetsByServiceID retrieves all targets for a given service
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
var targets []*rpservice.Target
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get targets from store")
}
return targets, nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var service *reverseproxy.Service var service *rpservice.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID) result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4919,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil return service, nil
} }
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets") var service *rpservice.Service
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
}
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
}
return serviceList, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4960,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil return service, nil
} }
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var serviceList []*reverseproxy.Service var serviceList []*rpservice.Service
result := tx.Find(&serviceList) result := tx.Find(&serviceList)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4982,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
return serviceList, nil return serviceList, nil
} }
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var serviceList []*reverseproxy.Service var serviceList []*rpservice.Service
result := tx.Find(&serviceList, accountIDCondition, accountID) result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -5004,6 +5040,99 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
return serviceList, nil return serviceList, nil
} }
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Update("meta_last_renewed_at", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
return status.Errorf(status.Internal, "renew ephemeral service")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
// skip malformed legacy data.
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
cutoff := time.Now().Add(-ttl)
var services []*rpservice.Service
result := s.db.
Select("id", "account_id", "source_peer", "domain").
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
Limit(limit).
Find(&services)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
}
return services, nil
}
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return count, nil
}
var ids []string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Pluck("id", &ids)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return int64(len(ids)), nil
}
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return count > 0, nil
}
var id string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Limit(1).
Pluck("id", &id)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return id != "", nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
tx := s.db tx := s.db
@@ -5216,13 +5345,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query return query
} }
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) { func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var target *reverseproxy.Target var target *rpservice.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID) result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -5235,3 +5364,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
return target, nil return target, nil
} }
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
}
return addresses, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
}
if result.RowsAffected > 0 {
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
}
return nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &types.AccountOnboarding{}, &service.Service{}, &service.Target{},
} }
for i := len(models) - 1; i >= 0; i-- { for i := len(models) - 1; i >= 0; i-- {

View File

@@ -25,9 +25,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -252,14 +253,18 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error CreateService(ctx context.Context, service *rpservice.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -271,9 +276,16 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
// GetCustomDomainsCounts returns the total and validated custom domain counts.
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
} }

View File

@@ -12,9 +12,10 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
dns "github.com/netbirdio/netbird/dns" dns "github.com/netbirdio/netbird/dns"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
zones "github.com/netbirdio/netbird/management/internals/modules/zones" zones "github.com/netbirdio/netbird/management/internals/modules/zones"
records "github.com/netbirdio/netbird/management/internals/modules/zones/records" records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
types "github.com/netbirdio/netbird/management/server/networks/resources/types" types "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
} }
// CleanupStaleProxies mocks base method.
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -193,6 +208,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
} }
// CountEphemeralServicesByPeer mocks base method.
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
}
// CreateAccessLog mocks base method. // CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -293,7 +323,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
} }
// CreateService mocks base method. // CreateService mocks base method.
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, service) ret := m.ctrl.Call(m, "CreateService", ctx, service)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -559,6 +589,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
} }
// DeleteServiceTargets mocks base method.
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
}
// DeleteSetupKey mocks base method. // DeleteSetupKey mocks base method.
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -573,6 +617,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
} }
// DeleteTarget mocks base method.
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTarget indicates an expected call of DeleteTarget.
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
}
// DeleteTokenID2UserIDIndex mocks base method. // DeleteTokenID2UserIDIndex mocks base method.
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error { func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -643,6 +701,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
} }
// EphemeralServiceExists mocks base method.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
}
// ExecuteInTransaction mocks base method. // ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error { func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1095,10 +1168,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
} }
// GetAccountServices mocks base method. // GetAccountServices mocks base method.
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID) ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service) ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1109,21 +1182,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
} }
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method. // GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) { func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1214,6 +1272,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
} }
// GetActiveProxyClusterAddresses mocks base method.
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetAllAccounts mocks base method. // GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1288,6 +1361,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
} }
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetDNSRecordByID mocks base method. // GetDNSRecordByID mocks base method.
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) { func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1303,6 +1392,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
} }
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
}
// GetGroupByID mocks base method. // GetGroupByID mocks base method.
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) { func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1828,10 +1932,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
} }
// GetServiceByDomain mocks base method. // GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*reverseproxy.Service) ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1843,10 +1947,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
} }
// GetServiceByID mocks base method. // GetServiceByID mocks base method.
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID) ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].(*reverseproxy.Service) ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1858,10 +1962,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
} }
// GetServiceTargetByTargetID mocks base method. // GetServiceTargetByTargetID mocks base method.
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) { func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID) ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
ret0, _ := ret[0].(*reverseproxy.Target) ret0, _ := ret[0].(*service.Target)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1872,27 +1976,11 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID)
} }
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetServices mocks base method. // GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength) ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
ret0, _ := ret[0].([]*reverseproxy.Service) ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1962,6 +2050,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
} }
// GetTargetsByServiceID mocks base method.
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].([]*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
}
// GetTokenIDByHashedToken mocks base method. // GetTokenIDByHashedToken mocks base method.
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) { func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2343,6 +2446,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
} }
// RenewEphemeralService mocks base method.
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
}
// RevokeProxyAccessToken mocks base method. // RevokeProxyAccessToken mocks base method.
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2567,6 +2684,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
} }
// SaveProxy mocks base method.
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
ret0, _ := ret[0].(error)
return ret0
}
// SaveProxy indicates an expected call of SaveProxy.
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// SaveProxyAccessToken mocks base method. // SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2762,8 +2893,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
} }
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
}
// UpdateService mocks base method. // UpdateService mocks base method.
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, service) ret := m.ctrl.Call(m, "UpdateService", ctx, service)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"` Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings // Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy()) networkResources = append(networkResources, resource.Copy())
} }
services := []*reverseproxy.Service{} services := []*service.Service{}
for _, service := range a.Services { for _, service := range a.Services {
services = append(services, service.Copy()) services = append(services, service.Copy())
} }
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
} }
} }
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets { for _, target := range service.Targets {
if !target.Enabled { if !target.Enabled {
continue continue
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
} }
} }
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) { func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target) port, ok := a.resolveTargetPort(ctx, target)
if !ok { if !ok {
return return
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
} }
} }
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) { func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
if target.Port != 0 { if target.Port != 0 {
return target.Port, true return target.Port, true
} }
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
} }
} }
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{ return &Policy{
ID: policyID, ID: policyID,

View File

@@ -742,6 +742,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
if err != nil { if err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err) return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err)
} }
// Ensure the initiator still has admin privileges
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
}
initiatorUser = freshInitiator initiatorUser = freshInitiator
} }
@@ -872,10 +877,6 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
return nil return nil
} }
if !initiatorUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
} }

View File

@@ -2032,27 +2032,6 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
}) })
} }
func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) {
groupsMap := map[string]*types.Group{}
initiator := &types.User{
Id: "initiator",
Role: types.UserRoleUser,
}
oldUser := &types.User{
Id: "target",
Role: types.UserRoleUser,
}
update := &types.User{
Id: "target",
Role: types.UserRoleOwner,
}
err := validateUserUpdate(groupsMap, initiator, oldUser, update)
require.Error(t, err, "regular user should not be able to promote to owner")
assert.Contains(t, err.Error(), "only admins and owners can update users")
}
func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) { func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err) require.NoError(t, err)
@@ -2109,7 +2088,7 @@ func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
}) })
require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted") require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted")
assert.Contains(t, err.Error(), "only admins and owners can update users") assert.Contains(t, err.Error(), "initiator role was changed during request processing")
targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID) targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -42,6 +42,8 @@ var (
acmeCerts bool acmeCerts bool
acmeAddr string acmeAddr string
acmeDir string acmeDir string
acmeEABKID string
acmeEABHMACKey string
acmeChallengeType string acmeChallengeType string
debugEndpoint bool debugEndpoint bool
debugEndpointAddr string debugEndpointAddr string
@@ -74,6 +76,8 @@ func init() {
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically") rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically")
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)") rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)")
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory") rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
rootCmd.Flags().StringVar(&acmeEABKID, "acme-eab-kid", envStringOrDefault("NB_PROXY_ACME_EAB_KID", ""), "ACME EAB KID for account registration")
rootCmd.Flags().StringVar(&acmeEABHMACKey, "acme-eab-hmac-key", envStringOrDefault("NB_PROXY_ACME_EAB_HMAC_KEY", ""), "ACME EAB HMAC key for account registration")
rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)") rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)")
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint") rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint") rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
@@ -149,6 +153,8 @@ func runServer(cmd *cobra.Command, args []string) error {
GenerateACMECertificates: acmeCerts, GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr, ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir, ACMEDirectory: acmeDir,
ACMEEABKID: acmeEABKID,
ACMEEABHMACKey: acmeEABHMACKey,
ACMEChallengeType: acmeChallengeType, ACMEChallengeType: acmeChallengeType,
DebugEndpointEnabled: debugEndpoint, DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr, DebugEndpointAddress: debugEndpointAddr,

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/asn1" "encoding/asn1"
"encoding/base64"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@@ -59,7 +60,10 @@ type Manager struct {
// NewManager creates a new ACME certificate manager. The certDir is used // NewManager creates a new ACME certificate manager. The certDir is used
// for caching certificates. The lockMethod controls cross-replica // for caching certificates. The lockMethod controls cross-replica
// coordination strategy (see CertLockMethod constants). // coordination strategy (see CertLockMethod constants).
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { // eabKID and eabHMACKey are optional External Account Binding credentials
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64
// URL-encoded string provided by the CA.
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
@@ -70,10 +74,26 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l
certNotifier: notifier, certNotifier: notifier,
logger: logger, logger: logger,
} }
var eab *acme.ExternalAccountBinding
if eabKID != "" && eabHMACKey != "" {
decodedKey, err := base64.RawURLEncoding.DecodeString(eabHMACKey)
if err != nil {
logger.Errorf("failed to decode EAB HMAC key: %v", err)
} else {
eab = &acme.ExternalAccountBinding{
KID: eabKID,
Key: decodedKey,
}
logger.Infof("configured External Account Binding with KID: %s", eabKID)
}
}
mgr.Manager = &autocert.Manager{ mgr.Manager = &autocert.Manager{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
HostPolicy: mgr.hostPolicy, HostPolicy: mgr.hostPolicy,
Cache: autocert.DirCache(certDir), Cache: autocert.DirCache(certDir),
ExternalAccountBinding: eab,
Client: &acme.Client{ Client: &acme.Client{
DirectoryURL: acmeURL, DirectoryURL: acmeURL,
}, },
@@ -136,7 +156,7 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
cert, err := mgr.GetCertificate(hello) cert, err := mgr.GetCertificate(hello)
elapsed := time.Since(start) elapsed := time.Since(start)
if err != nil { if err != nil {
mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err) mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err)
mgr.setDomainState(d, domainFailed, err.Error()) mgr.setDomainState(d, domainFailed, err.Error())
return return
} }

View File

@@ -10,7 +10,7 @@ import (
) )
func TestHostPolicy(t *testing.T) { func TestHostPolicy(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
mgr.AddDomain("example.com", "acc1", "rp1") mgr.AddDomain("example.com", "acc1", "rp1")
// Wait for the background prefetch goroutine to finish so the temp dir // Wait for the background prefetch goroutine to finish so the temp dir
@@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) {
} }
func TestDomainStates(t *testing.T) { func TestDomainStates(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")

View File

@@ -81,9 +81,10 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
rp := &httputil.ReverseProxy{ rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader), Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
Transport: p.transport, Transport: p.transport,
ErrorHandler: proxyErrorHandler, FlushInterval: -1,
ErrorHandler: proxyErrorHandler,
} }
if result.rewriteRedirects { if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose

View File

@@ -18,8 +18,9 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -37,7 +38,7 @@ type integrationTestSetup struct {
grpcServer *grpc.Server grpcServer *grpc.Server
grpcAddr string grpcAddr string
cleanup func() cleanup func()
services []*reverseproxy.Service services []*service.Service
} }
func setupIntegrationTest(t *testing.T) *integrationTestSetup { func setupIntegrationTest(t *testing.T) *integrationTestSetup {
@@ -66,13 +67,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
privKey := base64.StdEncoding.EncodeToString(priv) privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store // Create test services in the store
services := []*reverseproxy.Service{ services := []*service.Service{
{ {
ID: "rp-1", ID: "rp-1",
AccountID: "test-account-1", AccountID: "test-account-1",
Name: "Test App 1", Name: "Test App 1",
Domain: "app1.test.proxy.io", Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "10.0.0.1", Host: "10.0.0.1",
Port: 8080, Port: 8080,
@@ -91,7 +92,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
AccountID: "test-account-1", AccountID: "test-account-1",
Name: "Test App 2", Name: "Test App 2",
Domain: "app2.test.proxy.io", Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "10.0.0.2", Host: "10.0.0.2",
Port: 8080, Port: 8080,
@@ -112,7 +113,8 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
} }
// Create real token store // Create real token store
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
// Create real users manager // Create real users manager
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)
@@ -124,17 +126,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
} }
proxyManager := &testProxyManager{}
proxyService := nbgrpc.NewProxyServiceServer( proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{}, &testAccessLogManager{},
tokenStore, tokenStore,
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
proxyManager,
) )
// Use store-backed service manager // Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr) proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
// Start real gRPC server // Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -185,6 +193,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
return nil, 0, nil return nil, 0, nil
} }
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
// noop
}
func (c *testProxyController) GetOIDCValidationConfig() nbproxy.OIDCValidationConfig {
return nbproxy.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) GetProxiesForCluster(_ string) []string {
return nil
}
// storeBackedServiceManager reads directly from the real store. // storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct { type storeBackedServiceManager struct {
store store.Store store store.Store
@@ -195,19 +249,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou
return nil return nil
} }
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
} }
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@@ -219,7 +273,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context,
return nil return nil
} }
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return nil return nil
} }
@@ -231,15 +285,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID
return nil return nil
} }
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1") return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
} }
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
} }
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
@@ -247,8 +301,8 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context,
return "", nil return "", nil
} }
func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &reverseproxy.ExposeServiceResponse{}, nil return &service.ExposeServiceResponse{}, nil
} }
func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {

View File

@@ -84,6 +84,10 @@ type Server struct {
GenerateACMECertificates bool GenerateACMECertificates bool
ACMEChallengeAddress string ACMEChallengeAddress string
ACMEDirectory string ACMEDirectory string
// ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL).
ACMEEABKID string
// ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB.
ACMEEABHMACKey string
// ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01". // ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01".
// Defaults to "tls-alpn-01" if not specified. // Defaults to "tls-alpn-01" if not specified.
ACMEChallengeType string ACMEChallengeType string
@@ -419,7 +423,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
"acme_server": s.ACMEDirectory, "acme_server": s.ACMEDirectory,
"challenge_type": s.ACMEChallengeType, "challenge_type": s.ACMEChallengeType,
}).Debug("ACME certificates enabled, configuring certificate manager") }).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod) s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod)
if s.ACMEChallengeType == "http-01" { if s.ACMEChallengeType == "http-01" {
s.http = &http.Server{ s.http = &http.Server{

View File

@@ -11,6 +11,26 @@ import (
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
// APIError represents an error response from the management API.
type APIError struct {
StatusCode int
Message string
}
// Error implements the error interface.
func (e *APIError) Error() string {
return e.Message
}
// IsNotFound returns true if the error represents a 404 Not Found response.
func IsNotFound(err error) bool {
var apiErr *APIError
if ok := errors.As(err, &apiErr); ok {
return apiErr.StatusCode == http.StatusNotFound
}
return false
}
// Client Management service HTTP REST API Client // Client Management service HTTP REST API Client
type Client struct { type Client struct {
managementURL string managementURL string
@@ -105,6 +125,15 @@ type Client struct {
// Instance NetBird Instance API // Instance NetBird Instance API
// see more: https://docs.netbird.io/api/resources/instance // see more: https://docs.netbird.io/api/resources/instance
Instance *InstanceAPI Instance *InstanceAPI
// ReverseProxyServices NetBird reverse proxy services APIs
ReverseProxyServices *ReverseProxyServicesAPI
// ReverseProxyClusters NetBird reverse proxy clusters APIs
ReverseProxyClusters *ReverseProxyClustersAPI
// ReverseProxyDomains NetBird reverse proxy domains APIs
ReverseProxyDomains *ReverseProxyDomainsAPI
} }
// New initialize new Client instance using PAT token // New initialize new Client instance using PAT token
@@ -160,6 +189,9 @@ func (c *Client) initialize() {
c.IdentityProviders = &IdentityProvidersAPI{c} c.IdentityProviders = &IdentityProvidersAPI{c}
c.Ingress = &IngressAPI{c} c.Ingress = &IngressAPI{c}
c.Instance = &InstanceAPI{c} c.Instance = &InstanceAPI{c}
c.ReverseProxyServices = &ReverseProxyServicesAPI{c}
c.ReverseProxyClusters = &ReverseProxyClustersAPI{c}
c.ReverseProxyDomains = &ReverseProxyDomainsAPI{c}
} }
// NewRequest creates and executes new management API request // NewRequest creates and executes new management API request
@@ -194,10 +226,12 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re
if resp.StatusCode > 299 { if resp.StatusCode > 299 {
parsedErr, pErr := parseResponse[util.ErrorResponse](resp) parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
if pErr != nil { if pErr != nil {
return nil, pErr return nil, pErr
} }
return nil, errors.New(parsedErr.Message) return nil, &APIError{
StatusCode: resp.StatusCode,
Message: parsedErr.Message,
}
} }
return resp, nil return resp, nil

View File

@@ -0,0 +1,25 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// ReverseProxyClustersAPI APIs for Reverse Proxy Clusters, do not use directly
type ReverseProxyClustersAPI struct {
c *Client
}
// List lists all available proxy clusters
func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/clusters", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.ProxyCluster](resp)
return ret, err
}

View File

@@ -0,0 +1,72 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"net/url"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// ReverseProxyDomainsAPI APIs for Reverse Proxy Domains, do not use directly
type ReverseProxyDomainsAPI struct {
c *Client
}
// List lists all reverse proxy domains
func (a *ReverseProxyDomainsAPI) List(ctx context.Context) ([]api.ReverseProxyDomain, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.ReverseProxyDomain](resp)
return ret, err
}
// Create creates a new custom domain
func (a *ReverseProxyDomainsAPI) Create(ctx context.Context, request api.PostApiReverseProxiesDomainsJSONRequestBody) (*api.ReverseProxyDomain, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/domains", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ReverseProxyDomain](resp)
if err != nil {
return nil, err
}
return &ret, nil
}
// Delete deletes a custom domain
func (a *ReverseProxyDomainsAPI) Delete(ctx context.Context, domainID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/domains/"+url.PathEscape(domainID), nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Validate triggers domain ownership validation for a custom domain
func (a *ReverseProxyDomainsAPI) Validate(ctx context.Context, domainID string) error {
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/domains/"+url.PathEscape(domainID)+"/validate", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,97 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"net/url"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// ReverseProxyServicesAPI APIs for Reverse Proxy Services, do not use directly
type ReverseProxyServicesAPI struct {
c *Client
}
// List lists all reverse proxy services
func (a *ReverseProxyServicesAPI) List(ctx context.Context) ([]api.Service, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Service](resp)
return ret, err
}
// Get retrieves a reverse proxy service by ID
func (a *ReverseProxyServicesAPI) Get(ctx context.Context, serviceID string) (*api.Service, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Service](resp)
if err != nil {
return nil, err
}
return &ret, nil
}
// Create creates a new reverse proxy service
func (a *ReverseProxyServicesAPI) Create(ctx context.Context, request api.PostApiReverseProxiesServicesJSONRequestBody) (*api.Service, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/services", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Service](resp)
if err != nil {
return nil, err
}
return &ret, nil
}
// Update updates a reverse proxy service
func (a *ReverseProxyServicesAPI) Update(ctx context.Context, serviceID string, request api.PutApiReverseProxiesServicesServiceIdJSONRequestBody) (*api.Service, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Service](resp)
if err != nil {
return nil, err
}
return &ret, nil
}
// Delete deletes a reverse proxy service
func (a *ReverseProxyServicesAPI) Delete(ctx context.Context, serviceID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/services/"+url.PathEscape(serviceID), nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}