Compare commits

...

54 Commits

Author SHA1 Message Date
pascal
373f014aea Merge branch 'feature/check-cert-locker-before-acme' into test/proxy-fixes 2026-03-09 17:24:31 +01:00
pascal
2df3fb959b Merge branch 'feature/domain-activity-events' into test/proxy-fixes 2026-03-09 17:23:54 +01:00
pascal
14b0e9462b Merge branch 'feature/move-proxy-to-opentelemetry' into test/proxy-fixes 2026-03-09 17:23:42 +01:00
pascal
8db71b545e add certificate issue duration metrics 2026-03-09 16:59:33 +01:00
pascal
7c3532d8e5 fix test 2026-03-09 16:18:49 +01:00
pascal
1aa1eef2c5 account for streaming 2026-03-09 16:08:30 +01:00
pascal
d9418ddc1e do log throughput and requests. Also add throughput to the log entries 2026-03-09 15:39:29 +01:00
pascal
1b4c831976 add activity events for domains 2026-03-09 14:18:05 +01:00
pascal
a19611d8e0 check the cert locker before starting the acme challenge 2026-03-09 13:40:41 +01:00
pascal
9ab6138040 fix mapping counting and metrics registry 2026-03-09 13:17:50 +01:00
Pascal Fischer
30c02ab78c [management] use the cache for the pkce state (#5516) 2026-03-09 12:23:06 +01:00
Zoltan Papp
3acd86e346 [client] "reset connection" error on wake from sleep (#5522)
Capture engine reference before actCancel() in cleanupConnection().

After actCancel(), the connectWithRetryRuns goroutine sets engine to nil,
causing connectClient.Stop() to skip shutdown. This allows the goroutine
to set ErrResetConnection on the shared state after Down() clears it,
causing the next Up() to fail.
2026-03-09 10:25:51 +01:00
pascal
c2fec57c0f switch proxy to use opentelemetry 2026-03-07 11:16:40 +01:00
Pascal Fischer
5c20f13c48 [management] fix domain uniqueness (#5529) 2026-03-07 10:46:37 +01:00
Pascal Fischer
e6587b071d [management] use realip for proxy registration (#5525) 2026-03-06 16:11:44 +01:00
Maycon Santos
85451ab4cd [management] Add stable domain resolution for combined server (#5515)
The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
 Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
 standalone management server is unaffected.
2026-03-06 08:43:46 +01:00
Pascal Fischer
a7f3ba03eb [management] aggregate grpc metrics by accountID (#5486) 2026-03-05 22:10:45 +01:00
Maycon Santos
4f0a3a77ad [management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes**
  * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data.
  * Improved error handling when account data is unavailable with fallback to configured default domain.

* **Tests**
  * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
2026-03-05 14:30:31 +01:00
Maycon Santos
44655ca9b5 [misc] add PR title validation workflow (#5503) 2026-03-05 11:43:18 +01:00
Viktor Liu
e601278117 [management,proxy] Add per-target options to reverse proxy (#5501) 2026-03-05 10:03:26 +01:00
Maycon Santos
8e7b016be2 [management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart.

Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column:

- Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists
- Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass
- Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission 
- Add composite index on (source, source_peer) for efficient queries
- Batch-limit and column-select the reaper query to avoid DB/GC spikes
- Filter out malformed rows with empty source_peer
2026-03-04 18:15:13 +01:00
Maycon Santos
9e01ea7aae [misc] Add ISSUE_TEMPLATE configuration file (#5500)
Add issue template config file  with support and troubleshooting links
2026-03-04 14:30:54 +01:00
hbzhost
cfc7ec8bb9 [client] Fix SSH JWT auth failure with Azure Entra ID iat backdating (#5471)
Increase DefaultJWTMaxTokenAge from 5 to 10 minutes to accommodate
identity providers like Azure Entra ID that backdate the iat claim
by up to 5 minutes, causing tokens to be immediately rejected.

Fixes #5449

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-04 14:11:14 +01:00
Misha Bragin
b3bbc0e5c6 Fix embedded IdP metrics to count local and generic OIDC users (#5498) 2026-03-04 12:34:11 +02:00
Pascal Fischer
d7c8e37ff4 [management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-03-03 18:39:46 +01:00
Zoltan Papp
05b66e73bc [client] Fix deadlock in route peer status watcher (#5489)
Wrap peerStateUpdate send in a nested select to prevent goroutine
blocking when the consumer has exited, which could fill the
subscription buffer and deadlock the Status mutex.
2026-03-03 13:50:46 +01:00
Jeremie Deray
01ceedac89 [client] Fix profile config directory permissions (#5457)
* fix user profile dir perm

* fix fileExists

* revert return var change

* fix anti-pattern
2026-03-03 13:48:51 +01:00
Misha Bragin
403babd433 [self-hosted] specify sql file location of auth, activity and main store (#5487) 2026-03-03 12:53:16 +02:00
Maycon Santos
47133031e5 [client] fix: client/Dockerfile to reduce vulnerabilities (#5217)
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-03-03 08:44:08 +01:00
Pascal Fischer
82da606886 [management] Add explicit target delete on service removal (#5420) 2026-03-02 18:25:44 +01:00
Viktor Liu
bbe5ae2145 [client] Flush buffer immediately to support gprc (#5469) 2026-03-02 15:17:08 +01:00
Viktor Liu
0b21498b39 [client] Fix close of closed channel panic in ConnectClient retry loop (#5470) 2026-03-02 10:07:53 +01:00
Viktor Liu
0ca59535f1 [management] Add reverse proxy services REST client (#5454) 2026-02-28 13:04:58 +08:00
Misha Bragin
59c77d0658 [self-hosted] support embedded IDP postgres db (#5443)
* Add postgres config for embedded idp

Entire-Checkpoint: 9ace190c1067

* Rename idpStore to authStore

Entire-Checkpoint: 73a896c79614

* Fix review notes

Entire-Checkpoint: 6556783c0df3

* Don't accept pq port = 0

Entire-Checkpoint: 80d45e37782f

* Optimize configs

Entire-Checkpoint: 80d45e37782f

* Fix lint issues

Entire-Checkpoint: 3eec968003d1

* Fail fast on combined postgres config

Entire-Checkpoint: b17839d3d8c6

* Simplify management config method

Entire-Checkpoint: 0f083effa20e
2026-02-27 14:52:54 +01:00
shuuri-labs
333e045099 Lower socket auto-discovery log from Info to Debug (#5463)
The discovery message was printing on every CLI invocation, which is
noisy for users on distros using the systemd template.
2026-02-26 17:51:38 +01:00
Zoltan Papp
c2c4d9d336 [client] Fix Server mutex held across waitForUp in Up() (#5460)
Up() acquired s.mutex with a deferred unlock, then called waitForUp()
while still holding the lock. waitForUp() blocks for up to 50 seconds
waiting on clientRunningChan/clientGiveUpChan, starving all concurrent
gRPC calls that require the same mutex (Status, ListProfiles, etc.).

Replace the deferred unlock with explicit s.mutex.Unlock() on every
early-return path and immediately before waitForUp(), matching the
pattern already used by the clientRunning==true branch.
2026-02-26 16:47:02 +01:00
Bethuel Mmbaga
9a6a72e88e [management] Fix user update permission validation (#5441) 2026-02-24 22:47:41 +03:00
Bethuel Mmbaga
afe6d9fca4 [management] Prevent deletion of groups linked to flow groups (#5439) 2026-02-24 21:19:43 +03:00
shuuri-labs
ef82905526 [client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket.
- Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
2026-02-24 17:02:06 +01:00
Zoltan Papp
d18747e846 [client] Exclude Flow domain from caching to prevent TLS failures (#5433)
* Exclude Flow domain from caching to prevent TLS failures due to stale records.

* Fix test
2026-02-24 16:48:38 +01:00
Maycon Santos
f341d69314 [management] Add custom domain counts and service metrics to self-hosted metrics (#5414) 2026-02-24 15:21:14 +01:00
Maycon Santos
327142837c [management] Refactor expose feature: move business logic from gRPC to manager (#5435)
Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
2026-02-24 15:09:30 +01:00
Zoltan Papp
f8c0321aee [client] Simplify DNS logging by removing domain list from log output (#5396) 2026-02-24 10:35:45 +01:00
Zoltan Papp
89115ff76a [client] skip UAPI listener in netstack mode (#5397)
In netstack (proxy) mode, the process lacks permission to create
/var/run/wireguard, making the UAPI listener unnecessary and causing
a misleading error log. Introduce NewUSPConfigurerNoUAPI and use it
for the netstack device to avoid attempting to open the UAPI socket
entirely. Also consolidate UAPI error logging to a single call site.
2026-02-24 10:35:23 +01:00
Maycon Santos
63c83aa8d2 [client,management] Feature/client service expose (#5411)
CLI: new expose command to publish a local port with flags for PIN, password, user groups, custom domain, name prefix and protocol (HTTP default).
Management/API: create/renew/stop expose sessions (streamed status), automatic naming/domain, TTL renewals, background expiration, new management RPCs and client methods.
UI/API: account settings now include peer_expose_enabled and peer_expose_groups; new activity codes for peer expose events.
2026-02-24 10:02:16 +01:00
Zoltan Papp
37f025c966 Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state
transitions. The mutex now covers the full duration of each handler
including the status check, the Up/Down call, and the flag update.

Note: if Up or Down commands are triggered in parallel with sleep/wake
events, the overall ordering of up/down/sleep/wake operations is still
not guaranteed beyond what the mutex provides within the handler itself.
2026-02-24 10:00:33 +01:00
Zoltan Papp
4a54f0d670 [Client] Remove connection semaphore (#5419)
* [Client] Remove connection semaphore

Remove the semaphore and the initial random sleep time (300ms) from the connectivity logic to speed up the initial connection time.

Note: Implement limiter logic that can prioritize router peers and keep the fast connection option for the first few peers.

* Remove unused function
2026-02-23 20:58:53 +01:00
Zoltan Papp
98890a29e3 [client] fix busy-loop in network monitor routing socket on macOS/BSD (#5424)
* [client] fix busy-loop in network monitor routing socket on macOS/BSD

After system wakeup, the AF_ROUTE socket created by Go's unix.Socket()
is non-blocking, causing unix.Read to return EAGAIN immediately and spin
at 100% CPU filling the log with thousands of warnings per second.

Replace the tight read loop with a unix.Select call that blocks until
the fd is readable, checking ctx cancellation on each 1-second timeout.
Fatal errors (EBADF, EINVAL) now return an error instead of looping.

* [client] add fd range validation in waitReadable to prevent out-of-bound errors
2026-02-23 20:58:27 +01:00
Pascal Fischer
9d123ec059 [proxy] add pre-shared key support (#5377) 2026-02-23 16:31:29 +01:00
Pascal Fischer
5d171f181a [proxy] Send proxy updates on account delete (#5375) 2026-02-23 16:08:28 +01:00
Vlad
22f878b3b7 [management] network map components assembling (#5193) 2026-02-23 15:34:35 +01:00
Misha Bragin
44ef1a18dd [self-hosted] add Embedded IdP metrics (#5407) 2026-02-22 11:58:35 +02:00
Misha Bragin
2b98dc4e52 [self-hosted] Support activity store engine in the combined server (#5406) 2026-02-22 11:58:17 +02:00
Zoltan Papp
2a26cb4567 [client] stop upstream retry loop immediately on context cancellation (#5403)
stop upstream retry loop immediately on context cancellation
2026-02-20 14:44:14 +01:00
151 changed files with 17351 additions and 3812 deletions

14
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
skip: go.mod,go.sum,**/proxy/web/** skip: go.mod,go.sum,**/proxy/web/**
golangci: golangci:
strategy: strategy:

51
.github/workflows/pr-title-check.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2 FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

194
client/cmd/expose.go Normal file
View File

@@ -0,0 +1,194 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: "netbird expose --with-password safe-pass 8080",
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
})
if err != nil {
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
}
switch e := event.Event.(type) {
case *proto.ExposeServiceEvent_Ready:
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Port: %d\n", port)
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
return nil
default:
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
) )
@@ -80,6 +81,15 @@ var (
Short: "", Short: "",
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
} }
) )
@@ -144,6 +154,7 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd) rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil return conn, nil
} }
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -5,20 +5,18 @@ package configurer
import ( import (
"net" "net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
func openUAPI(deviceName string) (net.Listener, error) { func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName) uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil { if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err return nil, err
} }
listener, err := ipc.UAPIListen(deviceName, uapiSock) listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil { if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err) _ = uapiSock.Close()
return nil, err return nil, err
} }

View File

@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg return wgCfg
} }
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key") log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey) key, err := wgtypes.ParseKey(privateKey)

View File

@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
if cErr := tunIface.Close(); cErr != nil { if cErr := tunIface.Close(); cErr != nil {

View File

@@ -331,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan) close(runningChan)
runningChan = nil }
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -0,0 +1,60 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -0,0 +1,8 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -0,0 +1,121 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
} }
} }
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains) log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
return ruleIndex, nil return ruleIndex, nil
} }

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
} }
} }
if serverDomains.Flow != "" { // Flow receiver domain is intentionally excluded from caching.
domains = append(domains, serverDomains.Flow) // Cloud providers may rotate the IP behind this domain; a stale cached record
} // causes TLS certificate verification failures on reconnect.
for _, stun := range serverDomains.Stuns { for _, stun := range serverDomains.Stuns {
if stun != "" { if stun != "" {

View File

@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
} }
assert.Len(t, resolver.GetCachedDomains(), 3) assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing) // Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
partialDomains := dnsconfig.ServerDomains{ partialDomains := dnsconfig.ServerDomains{
Flow: "github.com", Flow: "github.com",
} }
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err) t.Skipf("Skipping test due to DNS resolution failure: %v", err)
} }
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
finalDomains := resolver.GetCachedDomains() finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
domainStrings := make([]string, len(finalDomains)) domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains { for i, d := range finalDomains {
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org") assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com") assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com") assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com") assert.NotContains(t, domainStrings, "github.com")
} }

View File

@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
err := backoff.Retry(operation, exponentialBackOff) err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil { if err != nil {
log.Warn(err) if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return return
} }

View File

@@ -36,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -53,13 +54,11 @@ import (
"github.com/netbirdio/netbird/client/internal/updatemanager" "github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec" "github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -75,7 +74,6 @@ import (
const ( const (
PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled" disableAutoUpdate = "disabled"
) )
@@ -208,7 +206,6 @@ type Engine struct {
syncRespMux sync.RWMutex syncRespMux sync.RWMutex
persistSyncResponse bool persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager flowManager nftypes.FlowManager
// auto-update // auto-update
@@ -224,6 +221,8 @@ type Engine struct {
jobExecutor *jobexec.Executor jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -266,7 +265,6 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager, stateManager: stateManager,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(), jobExecutor: jobexec.NewExecutor(),
} }
@@ -419,6 +417,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.cancel() e.cancel()
} }
e.ctx, e.cancel = context.WithCancel(e.clientCtx) e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
@@ -801,7 +800,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
disabled := autoUpdateSettings.Version == disableAutoUpdate disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled // stop and cleanup if disabled
if e.updateManager != nil && disabled { if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager") log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop() e.updateManager.Stop()
@@ -1539,7 +1538,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
IFaceDiscover: e.mobileDep.IFaceDiscover, IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager, RelayManager: e.relayManager,
SrWatcher: e.srWatcher, SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
} }
peerConn, err := peer.NewConn(config, serviceDependencies) peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
@@ -1824,11 +1822,18 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager return e.routeManager
} }
// GetFirewallManager returns the firewall manager // GetFirewallManager returns the firewall manager.
func (e *Engine) GetFirewallManager() firewallManager.Manager { func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall return e.firewall
} }
// GetExposeManager returns the expose session manager.
func (e *Engine) GetExposeManager() *expose.Manager {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return e.exposeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
}
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Pin string
Password string
UserGroups []string
}
type ManagementClient interface {
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
RenewExpose(ctx context.Context, domain string) error
StopExpose(ctx context.Context, domain string) error
}
// Manager handles expose session lifecycle via the management client.
type Manager struct {
mgmClient ManagementClient
ctx context.Context
}
// NewManager creates a new expose Manager using the given management client.
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
return &Manager{mgmClient: mgmClient, ctx: ctx}
}
// Expose creates a new expose session via the management server.
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
log.Infof("exposing service on port %d", req.Port)
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
if err != nil {
return nil, err
}
log.Infof("expose session created for %s", resp.Domain)
return fromClientExposeResponse(resp), nil
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)
for {
select {
case <-ctx.Done():
log.Infof("context canceled, stopping keep alive for %s", domain)
return nil
case <-ticker.C:
if err := m.renew(ctx, domain); err != nil {
log.Errorf("renewing expose session for %s: %v", domain, err)
return err
}
}
}
}
// renew extends the TTL of an active expose session.
func (m *Manager) renew(ctx context.Context, domain string) error {
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
defer cancel()
return m.mgmClient.RenewExpose(renewCtx, domain)
}
// stop terminates an active expose session.
func (m *Manager) stop(domain string) {
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
defer cancel()
err := m.mgmClient.StopExpose(stopCtx, domain)
if err != nil {
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
}
}

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
func TestManager_Expose_Success(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return &mgm.ExposeResponse{
ServiceName: "my-service",
ServiceURL: "https://my-service.example.com",
Domain: "my-service.example.com",
}, nil
},
}
m := NewManager(context.Background(), mock)
result, err := m.Expose(context.Background(), Request{Port: 8080})
require.NoError(t, err)
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
}
func TestManager_Expose_Error(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return nil, errors.New("permission denied")
},
}
m := NewManager(context.Background(), mock)
_, err := m.Expose(context.Background(), Request{Port: 8080})
require.Error(t, err)
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
}
func TestManager_Renew_Success(t *testing.T) {
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
return nil
},
}
m := NewManager(context.Background(), mock)
err := m.renew(context.Background(), "my-service.example.com")
require.NoError(t, err)
}
func TestManager_Renew_Timeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
return ctx.Err()
},
}
m := NewManager(ctx, mock)
err := m.renew(ctx, "my-service.example.com")
require.Error(t, err)
}
func TestNewRequest(t *testing.T) {
req := &daemonProto.ExposeServiceRequest{
Port: 8080,
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
Pin: "123456",
Password: "secret",
UserGroups: []string{"group1", "group2"},
Domain: "custom.example.com",
NamePrefix: "my-prefix",
}
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
}

View File

@@ -0,0 +1,39 @@
package expose
import (
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
}
}
func toClientExposeRequest(req Request) mgm.ExposeRequest {
return mgm.ExposeRequest{
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
}
}

View File

@@ -22,18 +22,24 @@ func prepareFd() (int, error) {
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for { for {
select { // Wait until fd is readable or context is cancelled, to avoid a busy-loop
case <-ctx.Done(): // when the routing socket returns EAGAIN (e.g. immediately after wakeup).
return ctx.Err() if err := waitReadable(ctx, fd); err != nil {
default: return err
}
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
if err != nil { if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue continue
} }
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("routing socket closed: %w", err)
}
return fmt.Errorf("read routing socket: %w", err)
}
if n < unix.SizeofRtMsghdr { if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue continue
@@ -70,7 +76,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
} }
} }
} }
}
} }
func parseRouteMessage(buf []byte) (*systemops.Route, error) { func parseRouteMessage(buf []byte) (*systemops.Route, error) {
@@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return systemops.MsgToRoute(msg) return systemops.MsgToRoute(msg)
} }
// waitReadable blocks until fd has data to read, or ctx is cancelled.
func waitReadable(ctx context.Context, fd int) error {
var fdset unix.FdSet
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
return fmt.Errorf("fd %d out of range for FdSet", fd)
}
for {
if err := ctx.Err(); err != nil {
return err
}
fdset = unix.FdSet{}
fdset.Set(fd)
// Use a 1-second timeout so we can re-check ctx periodically.
tv := unix.Timeval{Sec: 1}
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue
}
return fmt.Errorf("select on routing socket: %w", err)
}
if n > 0 {
return nil
}
// timeout — loop back and re-check ctx
}
}

View File

@@ -3,7 +3,6 @@ package peer
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
@@ -25,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ServiceDependencies struct { type ServiceDependencies struct {
@@ -34,7 +32,6 @@ type ServiceDependencies struct {
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher
} }
@@ -112,7 +109,6 @@ type Conn struct {
handshaker *Handshaker handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup wg sync.WaitGroup
// debug purpose // debug purpose
@@ -139,7 +135,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
iFaceDiscover: services.IFaceDiscover, iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager, relayManager: services.RelayManager,
srWatcher: services.SrWatcher, srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(), statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(),
dumpState: dumpState, dumpState: dumpState,
@@ -154,15 +149,10 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open(engineCtx context.Context) error { func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.opened { if conn.opened {
conn.semaphore.Done()
return nil return nil
} }
@@ -173,7 +163,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil { if err != nil {
conn.semaphore.Done()
return err return err
} }
conn.workerICE = workerICE conn.workerICE = workerICE
@@ -207,10 +196,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1) conn.wg.Add(1)
go func() { go func() {
defer conn.wg.Done() defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent) conn.guard.Start(conn.ctx, conn.onGuardEvent)
}() }()
conn.opened = true conn.opened = true
@@ -670,19 +655,6 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
} }
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
maxWait := 300
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-ctx.Done():
case <-timeout.C:
}
}
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority { switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn: case conntype.Relay, conntype.ICETurn:

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var testDispatcher = dispatcher.NewConnectionDispatcher() var testDispatcher = dispatcher.NewConnectionDispatcher()
@@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)
@@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"), StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)
@@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"), StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username) configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) { if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0600); err != nil { if err := os.MkdirAll(configDir, 0700); err != nil {
return "", err return "", err
} }
} }
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil return configDir, nil
} }
func fileExists(path string) bool { func fileExists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
return !os.IsNotExist(err) if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration // UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath) return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
} }
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one // UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath) err = util.EnforcePermission(input.ConfigPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
} }
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) { func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes. // DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). // Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) { func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {

View File

@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists return ErrProfileAlreadyExists
} }
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName) return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if !fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound return ErrProfileNotFound
} }

View File

@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
} }
stateFile := filepath.Join(configDir, profileName+".state.json") stateFile := filepath.Join(configDir, profileName+".state.json")
if !fileExists(stateFile) { stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
return nil, errors.New("profile state file does not exist") return nil, errors.New("profile state file does not exist")
} }

View File

@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer: case <-closer:
return return
case routerStates := <-subscription.Events(): case routerStates := <-subscription.Events():
peerStateUpdate <- routerStates select {
case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey) log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
} }
} }
} }

View File

@@ -0,0 +1,80 @@
package handler
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
)
type Agent interface {
Up(ctx context.Context) error
Down(ctx context.Context) error
Status() (internal.StatusType, error)
}
type SleepHandler struct {
agent Agent
mu sync.Mutex
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
sleepTriggeredDown bool
}
func New(agent Agent) *SleepHandler {
return &SleepHandler{
agent: agent,
}
}
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.sleepTriggeredDown {
log.Info("skipping up because wasn't sleep down")
return nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown = false
log.Info("running up after wake up")
err := s.agent.Up(ctx)
if err != nil {
log.Errorf("running up failed: %v", err)
return err
}
log.Info("running up command executed successfully")
return nil
}
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
status, err := s.agent.Status()
if err != nil {
return err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
return nil
}
log.Info("running down after system started sleeping")
if err = s.agent.Down(ctx); err != nil {
log.Errorf("running down failed: %v", err)
return err
}
s.sleepTriggeredDown = true
log.Info("running down executed successfully")
return nil
}

View File

@@ -0,0 +1,153 @@
package handler
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockAgent struct {
upErr error
downErr error
statusErr error
status internal.StatusType
upCalls int
}
func (m *mockAgent) Up(_ context.Context) error {
m.upCalls++
return m.upErr
}
func (m *mockAgent) Down(_ context.Context) error {
return m.downErr
}
func (m *mockAgent) Status() (internal.StatusType, error) {
return m.status, m.statusErr
}
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
agent := &mockAgent{status: status}
return New(agent), agent
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
h, _ := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
// Even if Up fails, flag should be reset
_ = h.HandleWakeUp(context.Background())
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
}
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
agent.upErr = errors.New("up failed")
err := h.HandleWakeUp(context.Background())
assert.ErrorIs(t, err, agent.upErr)
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
}
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
_ = h.HandleWakeUp(context.Background())
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.False(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.True(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
agent := &mockAgent{statusErr: errors.New("status error")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.statusErr)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.downErr)
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.6 // protoc-gen-go v1.36.6
// protoc v6.32.1 // protoc v6.33.3
// source: daemon.proto // source: daemon.proto
package proto package proto
@@ -88,6 +88,58 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0} return file_daemon_proto_rawDescGZIP(), []int{0}
} }
type ExposeProtocol int32
const (
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
)
// Enum value maps for ExposeProtocol.
var (
ExposeProtocol_name = map[int32]string{
0: "EXPOSE_HTTP",
1: "EXPOSE_HTTPS",
2: "EXPOSE_TCP",
3: "EXPOSE_UDP",
}
ExposeProtocol_value = map[string]int32{
"EXPOSE_HTTP": 0,
"EXPOSE_HTTPS": 1,
"EXPOSE_TCP": 2,
"EXPOSE_UDP": 3,
}
)
func (x ExposeProtocol) Enum() *ExposeProtocol {
p := new(ExposeProtocol)
*p = x
return p
}
func (x ExposeProtocol) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor()
}
func (ExposeProtocol) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1]
}
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use ExposeProtocol.Descriptor instead.
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{1}
}
// avoid collision with loglevel enum // avoid collision with loglevel enum
type OSLifecycleRequest_CycleType int32 type OSLifecycleRequest_CycleType int32
@@ -122,11 +174,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
} }
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor() return file_daemon_proto_enumTypes[2].Descriptor()
} }
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1] return &file_daemon_proto_enumTypes[2]
} }
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
@@ -174,11 +226,11 @@ func (x SystemEvent_Severity) String() string {
} }
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[2].Descriptor() return file_daemon_proto_enumTypes[3].Descriptor()
} }
func (SystemEvent_Severity) Type() protoreflect.EnumType { func (SystemEvent_Severity) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[2] return &file_daemon_proto_enumTypes[3]
} }
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
@@ -229,11 +281,11 @@ func (x SystemEvent_Category) String() string {
} }
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[3].Descriptor() return file_daemon_proto_enumTypes[4].Descriptor()
} }
func (SystemEvent_Category) Type() protoreflect.EnumType { func (SystemEvent_Category) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[3] return &file_daemon_proto_enumTypes[4]
} }
func (x SystemEvent_Category) Number() protoreflect.EnumNumber { func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
@@ -5600,6 +5652,224 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
return "" return ""
} }
type ExposeServiceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceRequest) Reset() {
*x = ExposeServiceRequest{}
mi := &file_daemon_proto_msgTypes[85]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceRequest) ProtoMessage() {}
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[85]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{85}
}
func (x *ExposeServiceRequest) GetPort() uint32 {
if x != nil {
return x.Port
}
return 0
}
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
if x != nil {
return x.Protocol
}
return ExposeProtocol_EXPOSE_HTTP
}
func (x *ExposeServiceRequest) GetPin() string {
if x != nil {
return x.Pin
}
return ""
}
func (x *ExposeServiceRequest) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *ExposeServiceRequest) GetUserGroups() []string {
if x != nil {
return x.UserGroups
}
return nil
}
func (x *ExposeServiceRequest) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *ExposeServiceRequest) GetNamePrefix() string {
if x != nil {
return x.NamePrefix
}
return ""
}
type ExposeServiceEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Event:
//
// *ExposeServiceEvent_Ready
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceEvent) Reset() {
*x = ExposeServiceEvent{}
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceEvent) ProtoMessage() {}
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{86}
}
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
if x != nil {
return x.Event
}
return nil
}
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
if x != nil {
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
return x.Ready
}
}
return nil
}
type isExposeServiceEvent_Event interface {
isExposeServiceEvent_Event()
}
type ExposeServiceEvent_Ready struct {
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
}
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
type ExposeServiceReady struct {
state protoimpl.MessageState `protogen:"open.v1"`
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceReady) Reset() {
*x = ExposeServiceReady{}
mi := &file_daemon_proto_msgTypes[87]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceReady) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceReady) ProtoMessage() {}
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[87]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{87}
}
func (x *ExposeServiceReady) GetServiceName() string {
if x != nil {
return x.ServiceName
}
return ""
}
func (x *ExposeServiceReady) GetServiceUrl() string {
if x != nil {
return x.ServiceUrl
}
return ""
}
func (x *ExposeServiceReady) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
type PortInfo_Range struct { type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5610,7 +5880,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() { func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{} *x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[86] mi := &file_daemon_proto_msgTypes[89]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@@ -5622,7 +5892,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {} func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86] mi := &file_daemon_proto_msgTypes[89]
if x != nil { if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@@ -6149,7 +6419,25 @@ const file_daemon_proto_rawDesc = "" +
"\x16InstallerResultRequest\"O\n" + "\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" + "\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
"\x14ExposeServiceRequest\x12\x12\n" +
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
"\vuser_groups\x18\x05 \x03(\tR\n" +
"userGroups\x12\x16\n" +
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
"\vname_prefix\x18\a \x01(\tR\n" +
"namePrefix\"Q\n" +
"\x12ExposeServiceEvent\x122\n" +
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
"\x05event\"p\n" +
"\x12ExposeServiceReady\x12!\n" +
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
"\vservice_url\x18\x02 \x01(\tR\n" +
"serviceUrl\x12\x16\n" +
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
"\bLogLevel\x12\v\n" + "\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" + "\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" +
@@ -6158,7 +6446,14 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" + "\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" + "\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xdd\x14\n" + "\x05TRACE\x10\a*S\n" +
"\x0eExposeProtocol\x12\x0f\n" +
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
"\n" +
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
"\n" +
"EXPOSE_UDP\x10\x032\xac\x15\n" +
"\rDaemonService\x126\n" + "\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6197,7 +6492,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
var ( var (
file_daemon_proto_rawDescOnce sync.Once file_daemon_proto_rawDescOnce sync.Once
@@ -6211,214 +6507,222 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData return file_daemon_proto_rawDescData
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
var file_daemon_proto_goTypes = []any{ var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType (ExposeProtocol)(0), // 1: daemon.ExposeProtocol
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity (OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category (SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest (SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest (*EmptyRequest)(nil), // 5: daemon.EmptyRequest
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse (*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
(*LoginRequest)(nil), // 7: daemon.LoginRequest (*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
(*LoginResponse)(nil), // 8: daemon.LoginResponse (*LoginRequest)(nil), // 8: daemon.LoginRequest
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest (*LoginResponse)(nil), // 9: daemon.LoginResponse
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse (*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
(*UpRequest)(nil), // 11: daemon.UpRequest (*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
(*UpResponse)(nil), // 12: daemon.UpResponse (*UpRequest)(nil), // 12: daemon.UpRequest
(*StatusRequest)(nil), // 13: daemon.StatusRequest (*UpResponse)(nil), // 13: daemon.UpResponse
(*StatusResponse)(nil), // 14: daemon.StatusResponse (*StatusRequest)(nil), // 14: daemon.StatusRequest
(*DownRequest)(nil), // 15: daemon.DownRequest (*StatusResponse)(nil), // 15: daemon.StatusResponse
(*DownResponse)(nil), // 16: daemon.DownResponse (*DownRequest)(nil), // 16: daemon.DownRequest
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest (*DownResponse)(nil), // 17: daemon.DownResponse
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse (*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
(*PeerState)(nil), // 19: daemon.PeerState (*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState (*PeerState)(nil), // 20: daemon.PeerState
(*SignalState)(nil), // 21: daemon.SignalState (*LocalPeerState)(nil), // 21: daemon.LocalPeerState
(*ManagementState)(nil), // 22: daemon.ManagementState (*SignalState)(nil), // 22: daemon.SignalState
(*RelayState)(nil), // 23: daemon.RelayState (*ManagementState)(nil), // 23: daemon.ManagementState
(*NSGroupState)(nil), // 24: daemon.NSGroupState (*RelayState)(nil), // 24: daemon.RelayState
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo (*NSGroupState)(nil), // 25: daemon.NSGroupState
(*SSHServerState)(nil), // 26: daemon.SSHServerState (*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
(*FullStatus)(nil), // 27: daemon.FullStatus (*SSHServerState)(nil), // 27: daemon.SSHServerState
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest (*FullStatus)(nil), // 28: daemon.FullStatus
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse (*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest (*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse (*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
(*IPList)(nil), // 32: daemon.IPList (*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
(*Network)(nil), // 33: daemon.Network (*IPList)(nil), // 33: daemon.IPList
(*PortInfo)(nil), // 34: daemon.PortInfo (*Network)(nil), // 34: daemon.Network
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule (*PortInfo)(nil), // 35: daemon.PortInfo
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse (*ForwardingRule)(nil), // 36: daemon.ForwardingRule
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest (*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse (*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest (*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse (*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest (*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse (*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
(*State)(nil), // 43: daemon.State (*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest (*State)(nil), // 44: daemon.State
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse (*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest (*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse (*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest (*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse (*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest (*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse (*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
(*TCPFlags)(nil), // 52: daemon.TCPFlags (*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest (*TCPFlags)(nil), // 53: daemon.TCPFlags
(*TraceStage)(nil), // 54: daemon.TraceStage (*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse (*TraceStage)(nil), // 55: daemon.TraceStage
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest (*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
(*SystemEvent)(nil), // 57: daemon.SystemEvent (*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest (*SystemEvent)(nil), // 58: daemon.SystemEvent
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse (*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest (*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse (*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest (*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse (*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest (*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse (*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest (*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse (*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest (*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse (*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
(*Profile)(nil), // 70: daemon.Profile (*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest (*Profile)(nil), // 71: daemon.Profile
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse (*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest (*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse (*LogoutRequest)(nil), // 74: daemon.LogoutRequest
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest (*LogoutResponse)(nil), // 75: daemon.LogoutResponse
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse (*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest (*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse (*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest (*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse (*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest (*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse (*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest (*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse (*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
nil, // 89: daemon.Network.ResolvedIPsEntry (*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range (*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
nil, // 91: daemon.SystemEvent.MetadataEntry (*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration (*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp nil, // 93: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
nil, // 95: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp 97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest 78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest 80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest 82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest 84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest 86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse 90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse 11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse 15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse 17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse 30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse 39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse 41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse 43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse 46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse 48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent 50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse 52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse 56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse 58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse 60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse 62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse 64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse 66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse 68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse 70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse 73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse 75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse 77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse 79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse 81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse 83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse 85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
69, // [69:104] is the sub-list for method output_type 87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
34, // [34:69] is the sub-list for method input_type 7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
34, // [34:34] is the sub-list for extension type_name 89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
34, // [34:34] is the sub-list for extension extendee 91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
0, // [0:34] is the sub-list for field type_name 72, // [72:108] is the sub-list for method output_type
36, // [36:72] is the sub-list for method input_type
36, // [36:36] is the sub-list for extension type_name
36, // [36:36] is the sub-list for extension extendee
0, // [0:36] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -6439,13 +6743,16 @@ func file_daemon_proto_init() {
file_daemon_proto_msgTypes[58].OneofWrappers = []any{} file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
file_daemon_proto_msgTypes[69].OneofWrappers = []any{} file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
file_daemon_proto_msgTypes[75].OneofWrappers = []any{} file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
(*ExposeServiceEvent_Ready)(nil),
}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4, NumEnums: 5,
NumMessages: 88, NumMessages: 91,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -103,6 +103,9 @@ service DaemonService {
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
} }
@@ -801,3 +804,32 @@ message InstallerResultResponse {
bool success = 1; bool success = 1;
string errorMsg = 2; string errorMsg = 2;
} }
enum ExposeProtocol {
EXPOSE_HTTP = 0;
EXPOSE_HTTPS = 1;
EXPOSE_TCP = 2;
EXPOSE_UDP = 3;
}
message ExposeServiceRequest {
uint32 port = 1;
ExposeProtocol protocol = 2;
string pin = 3;
string password = 4;
repeated string user_groups = 5;
string domain = 6;
string name_prefix = 7;
}
message ExposeServiceEvent {
oneof event {
ExposeServiceReady ready = 1;
}
}
message ExposeServiceReady {
string service_name = 1;
string service_url = 2;
string domain = 3;
}

View File

@@ -76,6 +76,8 @@ type DaemonServiceClient interface {
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -424,6 +426,38 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
return out, nil return out, nil
} }
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
if err != nil {
return nil, err
}
x := &daemonServiceExposeServiceClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type DaemonService_ExposeServiceClient interface {
Recv() (*ExposeServiceEvent, error)
grpc.ClientStream
}
type daemonServiceExposeServiceClient struct {
grpc.ClientStream
}
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
m := new(ExposeServiceEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -486,6 +520,8 @@ type DaemonServiceServer interface {
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -598,6 +634,9 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
} }
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1244,6 +1283,27 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(ExposeServiceRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
}
type DaemonService_ExposeServiceServer interface {
Send(*ExposeServiceEvent) error
grpc.ServerStream
}
type daemonServiceExposeServiceServer struct {
grpc.ServerStream
}
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
return x.ServerStream.SendMsg(m)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -1394,6 +1454,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
Handler: _DaemonService_SubscribeEvents_Handler, Handler: _DaemonService_SubscribeEvents_Handler,
ServerStreams: true, ServerStreams: true,
}, },
{
StreamName: "ExposeService",
Handler: _DaemonService_ExposeService_Handler,
ServerStreams: true,
},
}, },
Metadata: "daemon.proto", Metadata: "daemon.proto",
} }

View File

@@ -1,77 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -1,219 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -21,7 +21,9 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -85,8 +87,7 @@ type Server struct {
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down sleepHandler *sleephandler.SleepHandler
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache jwtCache *jwtCache
} }
@@ -100,7 +101,7 @@ type oauthAuthFlow struct {
// New server instance constructor. // New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
return &Server{ s := &Server{
rootCtx: ctx, rootCtx: ctx,
logFile: logFile, logFile: logFile,
persistSyncResponse: true, persistSyncResponse: true,
@@ -110,6 +111,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
updateSettingsDisabled: updateSettingsDisabled, updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(), jwtCache: newJWTCache(),
} }
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
return s
} }
func (s *Server) Start() error { func (s *Server) Start() error {
@@ -636,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err) log.Warnf(errRestoreResidualState, err)
} }
@@ -649,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
// not in the progress or already successfully established connection. // not in the progress or already successfully established connection.
status, err := state.Status() status, err := state.Status()
if err != nil { if err != nil {
s.mutex.Unlock()
return nil, err return nil, err
} }
if status != internal.StatusIdle { if status != internal.StatusIdle {
s.mutex.Unlock()
return nil, fmt.Errorf("up already in progress: current status %s", status) return nil, fmt.Errorf("up already in progress: current status %s", status)
} }
@@ -669,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.actCancel = cancel s.actCancel = cancel
if s.config == nil { if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first") return nil, fmt.Errorf("config is not defined, please call login command first")
} }
activeProf, err := s.profileManager.GetActiveProfileState() activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
if msg != nil && msg.ProfileName != nil { if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err) log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err) return nil, fmt.Errorf("failed to switch profile: %w", err)
} }
@@ -687,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
activeProf, err = s.profileManager.GetActiveProfileState() activeProf, err = s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
@@ -695,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
config, _, err := s.getConfig(activeProf) config, _, err := s.getConfig(activeProf)
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile config: %v", err) log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err) return nil, fmt.Errorf("failed to get active profile config: %w", err)
} }
@@ -713,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
} }
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }
@@ -838,15 +849,27 @@ func (s *Server) cleanupConnection() error {
if s.actCancel == nil { if s.actCancel == nil {
return ErrServiceNotUp return ErrServiceNotUp
} }
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
s.actCancel() s.actCancel()
if s.connectClient == nil { if s.connectClient == nil {
return nil return nil
} }
if err := s.connectClient.Stop(); err != nil { if engine != nil {
if err := engine.Stop(); err != nil {
return err return err
} }
}
s.connectClient = nil s.connectClient = nil
s.isSessionActive.Store(false) s.isSessionActive.Store(false)
@@ -1312,6 +1335,60 @@ func (s *Server) WaitJWTToken(
}, nil }, nil
} }
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.clientRunning {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
}
ctx := srv.Context()
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
defer exposeCancel()
mgmReq := expose.NewRequest(req)
result, err := mgr.Expose(exposeCtx, *mgmReq)
if err != nil {
return err
}
if err := srv.Send(&proto.ExposeServiceEvent{
Event: &proto.ExposeServiceEvent_Ready{
Ready: &proto.ExposeServiceReady{
ServiceName: result.ServiceName,
ServiceUrl: result.ServiceURL,
Domain: result.Domain,
},
},
}); err != nil {
return err
}
err = mgr.KeepAlive(ctx, result.Domain)
if err != nil {
return err
}
return nil
}
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false return false

46
client/server/sleep.go Normal file
View File

@@ -0,0 +1,46 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
type serverAgent struct {
s *Server
}
func (a *serverAgent) Up(ctx context.Context) error {
_, err := a.s.Up(ctx, &proto.UpRequest{})
return err
}
func (a *serverAgent) Down(ctx context.Context) error {
_, err := a.s.Down(ctx, &proto.DownRequest{})
return err
}
func (a *serverAgent) Status() (internal.StatusType, error) {
return internal.CtxGetState(a.s.rootCtx).Status()
}
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
case proto.OSLifecycleRequest_SLEEP:
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows return DefaultDaemonAddrWindows
} }
return DefaultDaemonAddr return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
} }
// DialOptions contains options for SSH connections // DialOptions contains options for SSH connections

View File

@@ -46,8 +46,10 @@ const (
cmdSFTP = "<sftp>" cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>" cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
DefaultJWTMaxTokenAge = 5 * 60 // Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
) )
var ( var (

View File

@@ -7,6 +7,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"path" "path"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -70,6 +71,8 @@ type ServerConfig struct {
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"` DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"` Store StoreConfig `yaml:"store"`
ActivityStore StoreConfig `yaml:"activityStore"`
AuthStore StoreConfig `yaml:"authStore"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"` ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
} }
@@ -171,6 +174,7 @@ type StoreConfig struct {
Engine string `yaml:"engine"` Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"` EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
File string `yaml:"file"` // SQLite database file path (optional, defaults to dataDir)
} }
// ReverseProxyConfig contains reverse proxy settings // ReverseProxyConfig contains reverse proxy settings
@@ -532,6 +536,74 @@ func stripSignalProtocol(uri string) string {
return uri return uri
} }
func buildRelayConfig(relays RelaysConfig) (*nbconfig.Relay, error) {
var ttl time.Duration
if relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", relays.CredentialsTTL, err)
}
}
return &nbconfig.Relay{
Addresses: relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: relays.Secret,
}, nil
}
// buildEmbeddedIdPConfig builds the embedded IdP configuration.
// authStore overrides auth.storage when set.
func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.EmbeddedIdPConfig, error) {
authStorageType := mgmt.Auth.Storage.Type
authStorageDSN := c.Server.AuthStore.DSN
if c.Server.AuthStore.Engine != "" {
authStorageType = c.Server.AuthStore.Engine
}
if authStorageType == "" {
authStorageType = "sqlite3"
}
authStorageFile := ""
if authStorageType == "postgres" {
if authStorageDSN == "" {
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
}
}
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
File: authStorageFile,
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
cfg.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password,
}
}
return cfg, nil
}
// ToManagementConfig converts CombinedConfig to management server config // ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management mgmt := c.Management
@@ -550,19 +622,11 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build relay config // Build relay config
var relayConfig *nbconfig.Relay var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" { if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration relay, err := buildRelayConfig(mgmt.Relays)
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err) return nil, err
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
} }
relayConfig = relay
} }
// Build signal config // Build signal config
@@ -598,31 +662,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
httpConfig := &nbconfig.HttpServerConfig{} httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server) // Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File embeddedIdP, err := c.buildEmbeddedIdPConfig(mgmt)
if storageFile == "" { if err != nil {
storageFile = path.Join(mgmt.DataDir, "idp.db") return nil, err
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
} }
// Set HTTP config fields for embedded IDP // Set HTTP config fields for embedded IDP

View File

@@ -140,6 +140,23 @@ func initializeConfig() error {
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
} }
} }
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
log.Infof("Starting combined NetBird server") log.Infof("Starting combined NetBird server")
logConfig(config) logConfig(config)
@@ -476,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address // Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil { if err != nil {
@@ -490,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtSrv := mgmtServer.NewServer( mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{ &mgmtServer.Config{
NbConfig: mgmtConfig, NbConfig: mgmtConfig,
DNSDomain: dnsDomain, DNSDomain: "",
MgmtSingleAccModeDomain: singleAccModeDomain, MgmtSingleAccModeDomain: "",
AutoResolveDomains: true,
MgmtPort: mgmtPort, MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort, MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics, DisableMetrics: mgmt.DisableAnonymousMetrics,
@@ -668,8 +683,11 @@ func logEnvVars() {
if strings.HasPrefix(env, "NB_") { if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=") key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key) value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") { keyLower := strings.ToLower(key)
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
value = maskSecret(value) value = maskSecret(value)
} else if strings.Contains(keyLower, "dsn") {
value = maskDSNPassword(value)
} }
log.Infof(" %s=%s", key, value) log.Infof(" %s=%s", key, value)
found = true found = true

View File

@@ -42,6 +42,9 @@ func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Sto
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn) os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
} }
} }
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine) engine := types.Engine(cfg.Management.Store.Engine)

View File

@@ -103,6 +103,19 @@ server:
engine: "sqlite" # sqlite, postgres, or mysql engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql dsn: "" # Connection string for postgres or mysql
encryptionKey: "" encryptionKey: ""
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/store.db)
# Activity events store configuration (optional, defaults to sqlite in dataDir)
# activityStore:
# engine: "sqlite" # sqlite or postgres
# dsn: "" # Connection string for postgres
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/events.db)
# Auth (embedded IdP) store configuration (optional, defaults to sqlite3 in dataDir/idp.db)
# authStore:
# engine: "sqlite3" # sqlite3 or postgres
# dsn: "" # Connection string for postgres (e.g., "host=localhost port=5432 user=postgres password=postgres dbname=netbird_idp sslmode=disable")
# file: "" # Custom SQLite file path (optional, defaults to {dataDir}/idp.db)
# Reverse proxy settings (optional) # Reverse proxy settings (optional)
# reverseProxy: # reverseProxy:

View File

@@ -5,7 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"net/url"
"os" "os"
"strconv"
"strings"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -195,11 +198,175 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config") return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
} }
return (&sql.SQLite3{File: file}).Open(logger) return (&sql.SQLite3{File: file}).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {
return nil, fmt.Errorf("postgres storage requires 'dsn' config")
}
pg, err := parsePostgresDSN(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres DSN: %w", err)
}
return pg.Open(logger)
default: default:
return nil, fmt.Errorf("unsupported storage type: %s", s.Type) return nil, fmt.Errorf("unsupported storage type: %s", s.Type)
} }
} }
// parsePostgresDSN parses a DSN into a sql.Postgres config.
// It accepts both URI format (postgres://user:pass@host:port/dbname?sslmode=disable)
// and libpq key=value format (host=localhost port=5432 dbname=mydb), including quoted values.
func parsePostgresDSN(dsn string) (*sql.Postgres, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
params, err = parsePostgresURI(dsn)
} else {
params, err = parsePostgresKeyValue(dsn)
}
if err != nil {
return nil, err
}
host := params["host"]
if host == "" {
host = "localhost"
}
var port uint16 = 5432
if p, ok := params["port"]; ok && p != "" {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %w", p, err)
}
if v == 0 {
return nil, fmt.Errorf("invalid port %q: must be non-zero", p)
}
port = uint16(v)
}
dbname := params["dbname"]
if dbname == "" {
return nil, fmt.Errorf("dbname is required in DSN")
}
pg := &sql.Postgres{
NetworkDB: sql.NetworkDB{
Host: host,
Port: port,
Database: dbname,
User: params["user"],
Password: params["password"],
},
}
if sslMode := params["sslmode"]; sslMode != "" {
switch sslMode {
case "disable", "allow", "prefer", "require", "verify-ca", "verify-full":
pg.SSL.Mode = sslMode
default:
return nil, fmt.Errorf("unsupported sslmode %q: valid values are disable, allow, prefer, require, verify-ca, verify-full", sslMode)
}
}
return pg, nil
}
// parsePostgresURI parses a postgres:// or postgresql:// URI into parameter key-value pairs.
func parsePostgresURI(dsn string) (map[string]string, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid postgres URI: %w", err)
}
params := make(map[string]string)
if u.User != nil {
params["user"] = u.User.Username()
if p, ok := u.User.Password(); ok {
params["password"] = p
}
}
if u.Hostname() != "" {
params["host"] = u.Hostname()
}
if u.Port() != "" {
params["port"] = u.Port()
}
dbname := strings.TrimPrefix(u.Path, "/")
if dbname != "" {
params["dbname"] = dbname
}
for k, v := range u.Query() {
if len(v) > 0 {
params[k] = v[0]
}
}
return params, nil
}
// parsePostgresKeyValue parses a libpq key=value DSN string, handling single-quoted values
// (e.g., password='my pass' host=localhost).
func parsePostgresKeyValue(dsn string) (map[string]string, error) {
params := make(map[string]string)
s := strings.TrimSpace(dsn)
for s != "" {
eqIdx := strings.IndexByte(s, '=')
if eqIdx < 0 {
break
}
key := strings.TrimSpace(s[:eqIdx])
value, rest, err := parseDSNValue(s[eqIdx+1:])
if err != nil {
return nil, fmt.Errorf("%w for key %q", err, key)
}
params[key] = value
s = strings.TrimSpace(rest)
}
return params, nil
}
// parseDSNValue parses the next value from a libpq key=value string positioned after the '='.
// It returns the parsed value and the remaining unparsed string.
func parseDSNValue(s string) (value, rest string, err error) {
if len(s) > 0 && s[0] == '\'' {
return parseQuotedDSNValue(s[1:])
}
// Unquoted value: read until whitespace.
idx := strings.IndexAny(s, " \t\n")
if idx < 0 {
return s, "", nil
}
return s[:idx], s[idx:], nil
}
// parseQuotedDSNValue parses a single-quoted value starting after the opening quote.
// Libpq uses ” to represent a literal single quote inside quoted values.
func parseQuotedDSNValue(s string) (value, rest string, err error) {
var buf strings.Builder
for len(s) > 0 {
if s[0] == '\'' {
if len(s) > 1 && s[1] == '\'' {
buf.WriteByte('\'')
s = s[2:]
continue
}
return buf.String(), s[1:], nil
}
buf.WriteByte(s[0])
s = s[1:]
}
return "", "", fmt.Errorf("unterminated quoted value")
}
// Validate validates the configuration // Validate validates the configuration
func (c *YAMLConfig) Validate() error { func (c *YAMLConfig) Validate() error {
if c.Issuer == "" { if c.Issuer == "" {

View File

@@ -63,6 +63,8 @@ type Controller struct {
expNewNetworkMap bool expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{} expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
newNetworkMapBuilder = false newNetworkMapBuilder = false
} }
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids)) expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids { for _, id := range ids {
@@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
holder: types.NewHolder(), holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder, expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs, expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
} }
} }
@@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) { switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else { case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} }
@@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
var remotePeerNetworkMap *types.NetworkMap var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) { switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else { case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} }
@@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else { } else {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers()) groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
} }
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
account.InjectProxyPolicies(ctx) account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
} }
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -24,6 +24,8 @@ type AccessLogEntry struct {
Reason string Reason string
UserId string `gorm:"index"` UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"` AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
} }
// FromProto creates an AccessLogEntry from a proto.AccessLog // FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.UserId = serviceLog.GetUserId() a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism() a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId() a.AccountID = serviceLog.GetAccountId()
a.BytesUpload = serviceLog.GetBytesUpload()
a.BytesDownload = serviceLog.GetBytesDownload()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil { if ip, err := netip.ParseAddr(sourceIP); err == nil {
@@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
AuthMethodUsed: authMethod, AuthMethodUsed: authMethod,
CountryCode: countryCode, CountryCode: countryCode,
CityName: cityName, CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
} }
} }

View File

@@ -15,3 +15,12 @@ type Domain struct {
Type Type `gorm:"-"` Type Type `gorm:"-"`
Validated bool Validated bool
} }
// EventMeta returns activity event metadata for a domain
func (d *Domain) EventMeta() map[string]any {
return map[string]any{
"domain": d.Domain,
"target_cluster": d.TargetCluster,
"validated": d.Validated,
}
}

View File

@@ -9,4 +9,5 @@ type Manager interface {
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string) ValidateDomain(ctx context.Context, accountID, userID, domainID string)
GetClusterDomains() []string
} }

View File

@@ -9,6 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -27,25 +29,25 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
} }
type proxyURLProvider interface { type proxyManager interface {
GetConnectedProxyURLs() []string GetActiveClusterAddresses(ctx context.Context) ([]string, error)
} }
type Manager struct { type Manager struct {
store store store store
validator domain.Validator validator domain.Validator
proxyURLProvider proxyURLProvider proxyManager proxyManager
permissionsManager permissions.Manager permissionsManager permissions.Manager
accountManager account.Manager
} }
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyURLProvider: proxyURLProvider, proxyManager: proxyMgr,
validator: domain.Validator{ validator: domain.Validator{Resolver: net.DefaultResolver},
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
accountManager: accountManager,
} }
} }
@@ -67,8 +69,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains. // Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
log.WithFields(log.Fields{ if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
"proxyAllowList": allowList, "proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list") }).Debug("getting domains with proxy allow list")
@@ -107,7 +113,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
} }
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false clusterValid := false
for _, cluster := range allowList { for _, cluster := range allowList {
if cluster == targetCluster { if cluster == targetCluster {
@@ -129,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
if err != nil { if err != nil {
return d, fmt.Errorf("create domain in store: %w", err) return d, fmt.Errorf("create domain in store: %w", err)
} }
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
return d, nil return d, nil
} }
@@ -141,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil { if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition. // TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err) return fmt.Errorf("delete domain from store: %w", err)
} }
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
return nil return nil
} }
@@ -211,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}).WithError(err).Error("update custom domain in store") }).WithError(err).Error("update custom domain in store")
return return
} }
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
@@ -221,21 +243,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
} }
} }
// proxyURLAllowList retrieves a list of currently connected proxies and // GetClusterDomains returns a list of proxy cluster domains.
// their URLs func (m Manager) GetClusterDomains() []string {
func (m Manager) proxyURLAllowList() []string { if m.proxyManager == nil {
var reverseProxyAddresses []string return nil
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
} }
return reverseProxyAddresses addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
} }
// DeriveClusterFromDomain determines the proxy cluster for a given domain. // DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 { if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available") return "", fmt.Errorf("no proxy clusters available")
} }

View File

@@ -1,539 +0,0 @@
package manager
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, service); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
}
service.AccountID = accountID
service.InitNewRecord()
if err := service.Auth.HashSecrets(); err != nil {
return fmt.Errorf("hash secrets: %w", err)
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
}
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
type serviceUpdateInfo struct {
oldCluster string
domainChanged bool
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
} else {
service.ProxyCluster = newCluster
}
}
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}

View File

@@ -1,375 +0,0 @@
package manager
import (
"context"
"errors"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func TestInitializeServiceForCreate(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{
clusterDeriver: nil,
}
service := &reverseproxy.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
assert.NoError(t, err)
assert.Equal(t, accountID, service.AccountID)
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
assert.NotEmpty(t, service.ID, "service ID should be initialized")
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
})
t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
})
}
func TestCheckDomainAvailable(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
tests := []struct {
name string
domain string
excludeServiceID string
setupMock func(*store.MockStore)
expectedError bool
errorType status.Type
}{
{
name: "domain available - not found",
domain: "available.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "available.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
},
expectedError: false,
},
{
name: "domain already exists",
domain: "exists.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
},
{
name: "domain exists but excluded (same ID)",
domain: "exists.com",
excludeServiceID: "service-123",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
{
name: "domain exists with different ID",
domain: "exists.com",
excludeServiceID: "service-456",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
},
{
name: "store error (non-NotFound)",
domain: "error.com",
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "error.com").
Return(nil, errors.New("database error"))
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
require.Error(t, err)
if tt.errorType != 0 {
sErr, ok := status.FromError(err)
require.True(t, ok, "error should be a status error")
assert.Equal(t, tt.errorType, sErr.Type())
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("empty domain", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
})
t.Run("empty exclude ID with existing service", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.AlreadyExists, sErr.Type())
})
t.Run("nil existing service with nil error", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &managerImpl{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
})
}
func TestPersistNewService(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("successful service creation with no targets", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
// Create another mock for the transaction
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "new.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
txMock.EXPECT().
CreateService(ctx, service).
Return(nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err)
})
t.Run("domain already exists", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
}
mockStore.EXPECT().
ExecuteInTransaction(ctx, gomock.Any()).
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.AlreadyExists, sErr.Type())
})
}
func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
Enabled: true,
Pin: "",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
},
}
mgr.preserveExistingAuthSecrets(updated, existing)
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
})
}
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
CertificateIssuedAt: time.Now(),
Status: "active",
},
SessionPrivateKey: "private-key",
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
Domain: "updated.com",
}
mgr.preserveServiceMetadata(updated, existing)
assert.Equal(t, existing.Meta, updated.Meta)
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
}

View File

@@ -0,0 +1,36 @@
package proxy
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -0,0 +1,88 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
metrics *metrics
}
// NewGRPCController creates a new GRPCController.
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &GRPCController{
proxyGRPCServer: proxyGRPCServer,
metrics: m,
}, nil
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
c.metrics.IncrementProxyConnectionCount(clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
c.metrics.DecrementProxyConnectionCount(clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -0,0 +1,115 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
metrics *metrics
}
// NewManager creates a new proxy Manager
func NewManager(store store, meter metric.Meter) (*Manager, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &Manager{
store: store,
metrics: m,
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type metrics struct {
proxyConnectionCount metric.Int64UpDownCounter
serviceUpdateSendCount metric.Int64Counter
proxyHeartbeatCount metric.Int64Counter
}
func newMetrics(meter metric.Meter) (*metrics, error) {
proxyConnectionCount, err := meter.Int64UpDownCounter(
"management_proxy_connection_count",
metric.WithDescription("Number of active proxy connections"),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, err
}
serviceUpdateSendCount, err := meter.Int64Counter(
"management_proxy_service_update_send_count",
metric.WithDescription("Total number of service updates sent to proxies"),
metric.WithUnit("{update}"),
)
if err != nil {
return nil, err
}
proxyHeartbeatCount, err := meter.Int64Counter(
"management_proxy_heartbeat_count",
metric.WithDescription("Total number of proxy heartbeats received"),
metric.WithUnit("{heartbeat}"),
)
if err != nil {
return nil, err
}
return &metrics{
proxyConnectionCount: proxyConnectionCount,
serviceUpdateSendCount: serviceUpdateSendCount,
proxyHeartbeatCount: proxyHeartbeatCount,
}, nil
}
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), -1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
m.serviceUpdateSendCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementProxyHeartbeatCount() {
m.proxyHeartbeatCount.Add(context.Background(), 1)
}

View File

@@ -0,0 +1,199 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./manager.go
// Package proxy is a generated GoMock package.
package proxy
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CleanupStale mocks base method.
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStale indicates an expected call of CleanupStale.
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,463 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type ProxyStatus string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,405 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}

View File

@@ -1,6 +1,6 @@
package reverseproxy package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod //go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import ( import (
"context" "context"
@@ -12,12 +12,17 @@ type Manager interface {
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error) GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StartExposeReaper(ctx context.Context)
} }

View File

@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go // Source: ./interface.go
// Package reverseproxy is a generated GoMock package. // Package service is a generated GoMock package.
package reverseproxy package service
import ( import (
context "context" context "context"
@@ -49,6 +49,35 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
} }
// CreateServiceFromPeer mocks base method.
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
ret0, _ := ret[0].(*ExposeServiceResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAllServices indicates an expected call of DeleteAllServices.
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteService mocks base method. // DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -181,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
} }
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
}
// SetCertificateIssuedAt mocks base method. // SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -196,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
} }
// SetStatus mocks base method. // SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error { func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status) ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -209,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
} }
// StartExposeReaper mocks base method.
func (m *MockManager) StartExposeReaper(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartExposeReaper", ctx)
}
// StartExposeReaper indicates an expected call of StartExposeReaper.
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
}
// UpdateService mocks base method. // UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) { func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
) )
type handler struct { type handler struct {
manager reverseproxy.Manager manager rpservice.Manager
} }
// RegisterEndpoints registers all service HTTP endpoints. // RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
@@ -72,8 +72,11 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId) if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
@@ -130,9 +133,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId) if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)

View File

@@ -0,0 +1,65 @@
package manager
import (
"context"
"math/rand/v2"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
exposeReapBatch = 100
)
type exposeReaper struct {
manager *Manager
}
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
go func() {
// start with a random delay
rn := rand.IntN(10)
time.Sleep(time.Duration(rn) * time.Second)
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
r.reapExpiredExposes(ctx)
}
}
}()
}
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
if err != nil {
log.Errorf("failed to get expired ephemeral services: %v", err)
return
}
for _, svc := range expired {
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
if err == nil {
continue
}
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
log.Debugf("service %s was already deleted by another instance", svc.Domain)
} else {
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
}
}
}

View File

@@ -0,0 +1,208 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
func TestReapExpiredExposes(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the service by backdating meta_last_renewed_at
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
require.NoError(t, err, "active service should remain")
}
func TestReapAlreadyDeletedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
mgr.exposeReaper.reapExpiredExposes(ctx)
}
func TestConcurrentReapAndRenew(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Expire all services
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
for _, svc := range services {
if svc.Source == rpservice.SourceEphemeral {
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
}
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
mgr.exposeReaper.reapExpiredExposes(ctx)
}()
go func() {
defer wg.Done()
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
}()
wg.Wait()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all expired services should be reaped")
}
func TestRenewEphemeralService(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
}
func TestCountAndExistsEphemeralServices(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
})
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
assert.True(t, exists, "service should exist")
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
require.NoError(t, err)
assert.False(t, exists, "non-existent service should not exist")
}
func TestMaxExposesPerPeerEnforced(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
}
func TestReapSkipsRenewedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
})
require.NoError(t, err)
// Expire the service
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)
svc.Meta.LastRenewedAt = &expired
err = s.UpdateService(context.Background(), svc)
require.NoError(t, err)
}

View File

@@ -0,0 +1,928 @@
package manager
import (
"context"
"fmt"
"math/rand/v2"
"slices"
"time"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
GetClusterDomains() []string
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr
}
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range s.Targets {
switch target.TargetType {
case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return s, nil
}
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
}
service.AccountID = accountID
service.InitNewRecord()
if err := service.Auth.HashSecrets(); err != nil {
return fmt.Errorf("hash secrets: %w", err)
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "domain already taken")
}
return nil
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
type serviceUpdateInfo struct {
oldCluster string
domainChanged bool
serviceEnabledChanged bool
}
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
} else {
service.ProxyCluster = newCluster
}
}
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
}
return nil
})
if err != nil {
return err
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, svc := range services {
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
now := time.Now()
service.Meta.CertificateIssuedAt = &now
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil
}
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return status.Errorf(status.Internal, "get account settings: %v", err)
}
if !settings.PeerExposeEnabled {
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
}
if len(settings.PeerExposeGroups) == 0 {
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
}
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
return status.Errorf(status.Internal, "get peer groups: %v", err)
}
for _, pg := range peerGroupIDs {
if slices.Contains(settings.PeerExposeGroups, pg) {
return nil
}
}
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
return nil, err
}
serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
svc := req.ToService(accountID, peerID, serviceName)
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
}
svc.Domain = domain
}
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
}
svc.Auth.BearerAuth.DistributionGroups = groupIDs
}
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
svc.SourcePeer = peerID
now := time.Now()
svc.Meta.LastRenewedAt = &now
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
return nil, err
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
}, nil
}
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided")
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}
groupIDs = append(groupIDs, g.ID)
}
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
// that it is still expired under a row lock. This prevents deleting a service
// that was renewed between the batch query and this delete, and ensures only one
// management instance processes the deletion
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
var svc *service.Service
deleted := false
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
}
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
return nil
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
deleted = true
return nil
})
if err != nil {
return err
}
if !deleted {
return nil
}
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
peer = nil
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
if peer == nil {
return meta
}
meta["peer_name"] = peer.Name
if peer.IP != nil {
meta["peer_ip"] = peer.IP.String()
}
return meta
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,817 @@
package service
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type Status string
const (
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
)
type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type Meta struct {
CreatedAt time.Time
CertificateIssuedAt *time.Time
Status string
LastRenewedAt *time.Time
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"type:varchar(255);uniqueIndex"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = Meta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
st := api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
}
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if s.Meta.CertificateIssuedAt != nil {
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode string
const (
PathRewritePreserve PathRewriteMode = "preserve"
)
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
switch mode {
case PathRewritePreserve:
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
default:
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
if opts.SkipTLSVerify {
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
}
if opts.RequestTimeout != 0 {
s := opts.RequestTimeout.String()
apiOpts.RequestTimeout = &s
}
if opts.PathRewrite != "" {
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
apiOpts.PathRewrite = &pr
}
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
}
return popts
}
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
var opts TargetOptions
if o.SkipTlsVerify != nil {
opts.SkipTLSVerify = *o.SkipTlsVerify
}
if o.RequestTimeout != nil {
d, err := time.ParseDuration(*o.RequestTimeout)
if err != nil {
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
}
opts.RequestTimeout = d
}
if o.PathRewrite != nil {
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
}
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
return opts, nil
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for i, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
return nil
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
}
return nil
}
const (
maxRequestTimeout = 5 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Proxy-Connection": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
}
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
var reservedHeaders = map[string]struct{}{
"Content-Length": {},
"Content-Type": {},
"Cookie": {},
"Forwarded": {},
"X-Forwarded-For": {},
"X-Forwarded-Host": {},
"X-Forwarded-Port": {},
"X-Forwarded-Proto": {},
"X-Real-Ip": {},
}
func validateTargetOptions(idx int, opts *TargetOptions) error {
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
}
if opts.RequestTimeout != 0 {
if opts.RequestTimeout <= 0 {
return fmt.Errorf("target %d: request_timeout must be positive", idx)
}
if opts.RequestTimeout > maxRequestTimeout {
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
}
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
return err
}
return nil
}
func validateCustomHeaders(idx int, headers map[string]string) error {
if len(headers) > maxCustomHeaders {
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
}
seen := make(map[string]string, len(headers))
for key, value := range headers {
if !httpHeaderNameRe.MatchString(key) {
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
}
if len(key) > maxHeaderKeyLen {
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
}
if len(value) > maxHeaderValueLen {
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
}
if containsCRLF(key) || containsCRLF(value) {
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
}
canonical := http.CanonicalHeaderKey(key)
if prev, ok := seen[canonical]; ok {
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
}
seen[canonical] = key
if _, ok := hopByHopHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
}
if _, ok := reservedHeaders[canonical]; ok {
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
}
if canonical == "Host" {
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
}
}
return nil
}
func containsCRLF(s string) bool {
return strings.ContainsAny(s, "\r\n")
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
}
func (s *Service) isAuthEnabled() bool {
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
Source: s.Source,
SourcePeer: s.SourcePeer,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
NamePrefix string
Port int
Protocol string
Domain string
Pin string
Password string
UserGroups []string
}
// Validate checks all fields of the expose request.
func (r *ExposeServiceRequest) Validate() error {
if r == nil {
return errors.New("request cannot be nil")
}
if r.Port < 1 || r.Port > 65535 {
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
}
if r.Protocol != "http" && r.Protocol != "https" {
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
}
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
return errors.New("invalid pin: must be exactly 6 digits")
}
for _, g := range r.UserGroups {
if g == "" {
return errors.New("user group name cannot be empty")
}
}
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
}
return nil
}
// ToService builds a Service from the expose request.
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
service := &Service{
AccountID: accountID,
Name: serviceName,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: r.Protocol,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
},
}
if r.Domain != "" {
service.Domain = serviceName + "." + r.Domain
}
if r.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: r.Pin,
}
}
if r.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: r.Password,
}
}
if len(r.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: r.UserGroups,
}
}
return service
}
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
ServiceName string
ServiceURL string
Domain string
}
// GenerateExposeName generates a random service name for peer-exposed services.
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
func GenerateExposeName(prefix string) (string, error) {
if prefix != "" && !validNamePrefix.MatchString(prefix) {
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
}
suffixLen := 12
if prefix != "" {
suffixLen = 4
}
suffix, err := randomAlphanumeric(suffixLen)
if err != nil {
return "", fmt.Errorf("generate random name: %w", err)
}
if prefix == "" {
return suffix, nil
}
return prefix + "-" + suffix, nil
}
func randomAlphanumeric(n int) (string, error) {
result := make([]byte, n)
charsetLen := big.NewInt(int64(len(alphanumCharset)))
for i := range result {
idx, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
return "", err
}
result[i] = alphanumCharset[idx.Int64()]
}
return string(result), nil
}

View File

@@ -0,0 +1,732 @@
package service
import (
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
tests := []struct {
name string
mode PathRewriteMode
wantErr string
}{
{"empty is default", "", ""},
{"preserve is valid", PathRewritePreserve, ""},
{"unknown rejected", "regex", "unknown path_rewrite mode"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.PathRewrite = tt.mode
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantErr string
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.RequestTimeout = tt.timeout
err := rp.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tt.wantErr)
}
})
}
}
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
t.Run("valid headers", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"X-Custom": "value",
"X-Trace": "abc123",
}
assert.NoError(t, rp.Validate())
})
t.Run("CRLF in key", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
})
t.Run("CRLF in value", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
assert.ErrorContains(t, rp.Validate(), "invalid characters")
})
t.Run("hop-by-hop header rejected", func(t *testing.T) {
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
}
})
t.Run("reserved header rejected", func(t *testing.T) {
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
}
})
t.Run("Host header rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
})
t.Run("too many headers", func(t *testing.T) {
rp := validProxy()
headers := make(map[string]string, 17)
for i := range 17 {
headers[fmt.Sprintf("X-H%d", i)] = "v"
}
rp.Targets[0].Options.CustomHeaders = headers
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
})
t.Run("key too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
assert.ErrorContains(t, rp.Validate(), "key")
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
})
t.Run("value too long", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
})
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
rp := validProxy()
rp.Targets[0].Options.CustomHeaders = map[string]string{
"x-custom": "a",
"X-Custom": "b",
}
assert.ErrorContains(t, rp.Validate(), "collide")
})
}
func TestToProtoMapping_TargetOptions(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
Options: TargetOptions{
SkipTLSVerify: true,
RequestTimeout: 30 * time.Second,
PathRewrite: PathRewritePreserve,
CustomHeaders: map[string]string{"X-Custom": "val"},
},
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
opts := pm.Path[0].Options
require.NotNil(t, opts, "options should be populated")
assert.True(t, opts.SkipTlsVerify)
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
require.NotNil(t, opts.RequestTimeout)
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
}
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := proxy.OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}
func TestGenerateExposeName(t *testing.T) {
t.Run("no prefix generates 12-char name", func(t *testing.T) {
name, err := GenerateExposeName("")
require.NoError(t, err)
assert.Len(t, name, 12)
assert.Regexp(t, `^[a-z0-9]+$`, name)
})
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
name, err := GenerateExposeName("myapp")
require.NoError(t, err)
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
suffix := strings.TrimPrefix(name, "myapp-")
assert.Len(t, suffix, 4, "suffix should be 4 chars")
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
})
t.Run("unique names", func(t *testing.T) {
names := make(map[string]bool)
for i := 0; i < 50; i++ {
name, err := GenerateExposeName("")
require.NoError(t, err)
names[name] = true
}
assert.Greater(t, len(names), 45, "should generate mostly unique names")
})
t.Run("valid prefixes", func(t *testing.T) {
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
for _, prefix := range validPrefixes {
name, err := GenerateExposeName(prefix)
assert.NoError(t, err, "prefix %q should be valid", prefix)
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
}
})
t.Run("invalid prefixes", func(t *testing.T) {
invalidPrefixes := []string{
"-starts-with-dash",
"ends-with-dash-",
"has.dots",
"HAS-UPPER",
"has spaces",
"has/slash",
"a--",
}
for _, prefix := range invalidPrefixes {
_, err := GenerateExposeName(prefix)
assert.Error(t, err, "prefix %q should be invalid", prefix)
assert.Contains(t, err.Error(), "invalid name prefix")
}
})
}
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
assert.Equal(t, "account-1", service.AccountID)
assert.Equal(t, "mysvc", service.Name)
assert.True(t, service.Enabled)
assert.Empty(t, service.Domain, "domain should be empty when not specified")
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
assert.True(t, target.Enabled)
assert.Equal(t, "account-1", target.AccountID)
})
t.Run("with custom domain", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 3000,
Domain: "example.com",
}
service := req.ToService("acc", "peer", "web")
assert.Equal(t, "web.example.com", service.Domain)
})
t.Run("with PIN auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Pin: "1234",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PinAuth)
assert.True(t, service.Auth.PinAuth.Enabled)
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
assert.Nil(t, service.Auth.PasswordAuth)
assert.Nil(t, service.Auth.BearerAuth)
})
t.Run("with password auth", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
Password: "secret",
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PasswordAuth)
assert.True(t, service.Auth.PasswordAuth.Enabled)
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
})
t.Run("with user groups (bearer auth)", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 80,
UserGroups: []string{"admins", "devs"},
}
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.BearerAuth)
assert.True(t, service.Auth.BearerAuth.Enabled)
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
})
t.Run("with all auth types", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 443,
Domain: "myco.com",
Pin: "9999",
Password: "pass",
UserGroups: []string{"ops"},
}
service := req.ToService("acc", "peer", "full")
assert.Equal(t, "full.myco.com", service.Domain)
require.NotNil(t, service.Auth.PinAuth)
require.NotNil(t, service.Auth.PasswordAuth)
require.NotNil(t, service.Auth.BearerAuth)
})
}

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" { if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil { if err != nil {
log.Fatalf("failed to create certificate manager: %v", err) log.Fatalf("failed to create certificate service: %v", err)
} }
transportCredentials := credentials.NewTLS(certManager.TLSConfig()) transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -152,6 +152,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil { if err != nil {
log.Fatalf("failed to create management server: %v", err) log.Fatalf("failed to create management server: %v", err)
} }
serviceMgr := s.ServiceManager()
srv.SetReverseProxyManager(serviceMgr)
if serviceMgr != nil {
serviceMgr.StartExposeReaper(context.Background())
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer()) mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
@@ -163,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
}) })
return proxyService return proxyService
}) })
@@ -188,12 +194,25 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication") log.Info("One-time token store initialized for proxy authentication")
return tokenStore return tokenStore
}) })
} }
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
return Create(s, func() *nbgrpc.PKCEVerifierStore {
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create PKCE verifier store: %v", err)
}
return pkceStore
})
}
func (s *BaseServer) AccessLogsManager() accesslogs.Manager { func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager { return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
}) })
} }
func (s *BaseServer) ServiceProxyController() proxy.Controller {
return Create(s, func() proxy.Controller {
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create service proxy controller: %v", err)
}
return controller
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store()) return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager { return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil { if err != nil {
log.Fatalf("failed to create account manager: %v", err) log.Fatalf("failed to create account service: %v", err)
} }
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager()) accountManager.SetServiceManager(s.ServiceManager())
}) })
return accountManager return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager { return Create(s, func() idp.Manager {
var idpManager idp.Manager var idpManager idp.Manager
var err error var err error
// Use embedded IdP manager if embedded Dex is configured and enabled. // Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured. // Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err) log.Fatalf("failed to create embedded IDP service: %v", err)
} }
return idpManager return idpManager
} }
// Fall back to external IdP manager // Fall back to external IdP service
if s.Config.IdpManagerConfig != nil { if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create IDP manager: %v", err) log.Fatalf("failed to create IDP service: %v", err)
} }
} }
return idpManager return idpManager
}) })
} }
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil // OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager()) return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
}) })
} }
@@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager {
}) })
} }
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() reverseproxy.Manager { return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
if err != nil {
log.Fatalf("failed to create proxy manager: %v", err)
}
return manager
}) })
} }
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
return &m return &m
}) })
} }

View File

@@ -28,9 +28,13 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. const (
// It is used for backward compatibility now. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
const ManagementLegacyPort = 33073 // It is used for backward compatibility now.
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
)
type Server interface { type Server interface {
Start(ctx context.Context) error Start(ctx context.Context) error
@@ -58,6 +62,7 @@ type BaseServer struct {
mgmtMetricsPort int mgmtMetricsPort int
mgmtPort int mgmtPort int
disableLegacyManagementPort bool disableLegacyManagementPort bool
autoResolveDomains bool
proxyAuthClose func() proxyAuthClose func()
@@ -81,6 +86,7 @@ type Config struct {
DisableMetrics bool DisableMetrics bool
DisableGeoliteUpdate bool DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool UserDeleteFromIDPEnabled bool
AutoResolveDomains bool
} }
// NewServer initializes and configures a new Server instance // NewServer initializes and configures a new Server instance
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
mgmtPort: cfg.MgmtPort, mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort, mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
} }
} }
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.cancel = cancel s.cancel = cancel
s.errCh = make(chan error, 4) s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
}
s.PeersManager() s.PeersManager()
s.GeoLocationManager() s.GeoLocationManager()
@@ -157,7 +168,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered // Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks // before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager). // registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer() s.GRPCServer()
for _, fn := range s.afterInit { for _, fn := range s.afterInit {
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close() _ = s.certManager.Listener().Close()
} }
s.GRPCServer().Stop() s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil { if s.proxyAuthClose != nil {
s.proxyAuthClose() s.proxyAuthClose()
s.proxyAuthClose = nil s.proxyAuthClose = nil
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}() }()
} }
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
func (s *BaseServer) resolveDomains(ctx context.Context) {
st := s.Store()
setDefault := func(logMsg string, args ...any) {
if logMsg != "" {
log.WithContext(ctx).Warnf(logMsg, args...)
}
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
}
accountsCount, err := st.GetAccountsCounter(ctx)
if err != nil {
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountsCount == 0 {
s.dnsDomain = DefaultSelfHostedDomain
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
return
}
accountID, err := st.GetAnyAccountID(ctx)
if err != nil {
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
return
}
if accountID == "" {
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
return
}
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
return
}
if domain == "" {
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
return
}
s.dnsDomain = domain
s.mgmtSingleAccModeDomain = domain
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) { func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID() installationID := store.GetInstallationID()
if installationID != "" { if installationID != "" {

View File

@@ -0,0 +1,63 @@
package server
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
)
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
}
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}

View File

@@ -0,0 +1,192 @@
package grpc
import (
"context"
pb "github.com/golang/protobuf/proto" // nolint
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
// CreateExpose handles a peer request to create a new expose service.
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
exposeReq := &proto.ExposeServiceRequest{}
peerKey, err := s.parseRequest(ctx, req, exposeReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
})
if err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
})
}
// RenewExpose extends the TTL of an active expose session.
func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
renewReq := &proto.RenewExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, renewReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
}
// StopExpose terminates an active expose session.
func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
stopReq := &proto.StopExposeRequest{}
peerKey, err := s.parseRequest(ctx, req, stopReq)
if err != nil {
return nil, err
}
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
}
func mapExposeError(ctx context.Context, err error) error {
s, ok := internalStatus.FromError(err)
if !ok {
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
switch s.Type() {
case internalStatus.InvalidArgument:
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, "%s", s.Message)
case internalStatus.AlreadyExists:
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
default:
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
}
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, wgKey, msg)
if err != nil {
return nil, status.Errorf(codes.Internal, "encrypt response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key) (string, *nbpeer.Peer, error) {
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return "", nil, status.Errorf(codes.Internal, "lookup account for peer")
}
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return "", nil, status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return accountID, peer, nil
}
func (s *Server) getReverseProxyManager() rpservice.Manager {
s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock()
return s.reverseProxyManager
}
// SetReverseProxyManager sets the reverse proxy manager on the server.
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyMu.Lock()
defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
switch p {
case proto.ExposeProtocol_EXPOSE_HTTP:
return "http"
case proto.ExposeProtocol_EXPOSE_HTTPS:
return "https"
default:
return "http"
}
}

View File

@@ -1,28 +1,23 @@
package grpc package grpc
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
) )
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct { type tokenMetadata struct {
ServiceID string ServiceID string
AccountID string AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time CreatedAt time.Time
} }
// NewOneTimeTokenStore creates a new token store with automatic cleanup // OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// of expired tokens. The cleanupInterval determines how often expired // Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
// tokens are removed from memory. type OneTimeTokenStore struct {
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { cache *cache.Cache[string]
store := &OneTimeTokenStore{ ctx context.Context
tokens: make(map[string]*tokenMetadata), }
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}), // NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
} }
// Start background cleanup goroutine return &OneTimeTokenStore{
go store.cleanupExpired() cache: cache.New[string](cacheStore),
ctx: ctx,
return store }, nil
} }
// GenerateToken creates a new cryptographically secure one-time token // GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// //
// Returns the generated token string or an error if random generation fails. // Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32) randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil { if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err) return "", fmt.Errorf("failed to generate random token: %w", err)
} }
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes) token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock() metadata := &tokenMetadata{
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ServiceID: serviceID, ServiceID: serviceID,
AccountID: accountID, AccountID: accountID,
ExpiresAt: time.Now().Add(ttl), ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl) serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match // - Account ID doesn't match
// - Reverse proxy ID doesn't match // - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock() hashedToken := hashToken(token)
defer s.mu.Unlock()
metadata, exists := s.tokens[token] metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if !exists { if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
serviceID, accountID)
return fmt.Errorf("invalid token") return fmt.Errorf("invalid token")
} }
// Check expiration metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) { if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token) log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
return fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch") return fmt.Errorf("account ID mismatch")
} }
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch") return fmt.Errorf("service ID mismatch")
} }
// Delete token immediately to enforce single-use if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
delete(s.tokens, token) log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s", log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
serviceID, accountID)
return nil return nil
} }
// cleanupExpired removes expired tokens in the background to prevent memory leaks func hashToken(token string) string {
func (s *OneTimeTokenStore) cleanupExpired() { hash := sha256.Sum256([]byte(token))
for { return hex.EncodeToString(hash[:])
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
} }

View File

@@ -0,0 +1,61 @@
package grpc
import (
"context"
"fmt"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type PKCEVerifierStore struct {
cache *cache.Cache[string]
ctx context.Context
}
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
return &PKCEVerifierStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
// Store saves a PKCE verifier associated with an OAuth state parameter.
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
return fmt.Errorf("failed to store PKCE verifier: %w", err)
}
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
return nil
}
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
// Returns the verifier and true if found, or empty string and false if not found.
// This enforces single-use semantics for PKCE verifiers.
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
verifier, err := s.cache.Get(s.ctx, state)
if err != nil {
log.Debugf("PKCE verifier not found for state")
return "", false
}
if err := s.cache.Delete(s.ctx, state); err != nil {
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
}
return verifier, true
}

View File

@@ -18,14 +18,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
@@ -58,17 +58,17 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection // Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
// Manager for access logs // Manager for access logs
accessLogManager accesslogs.Manager accessLogManager accesslogs.Manager
// Manager for reverse proxy operations // Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController proxy.Controller
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers // Manager for peers
peersManager peers.Manager peersManager peers.Manager
@@ -82,84 +82,67 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication // OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig oidcConfig ProxyOIDCConfig
// TODO: use database to store these instead? // Store for PKCE verifiers
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state. pkceVerifierStore *PKCEVerifierStore
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
} }
const pkceVerifierTTL = 10 * time.Minute const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy // proxyConnection represents a connected proxy
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
address string address string
stream proto.ProxyService_GetMappingUpdateServer stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.ProxyMapping sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx := context.Background()
s := &ProxyServiceServer{ s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr, accessLogManager: accessLogMgr,
oidcConfig: oidcConfig, oidcConfig: oidcConfig,
tokenStore: tokenStore, tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
pkceCleanupCancel: cancel, proxyManager: proxyMgr,
} }
go s.cleanupPKCEVerifiers(ctx) go s.cleanupStaleProxies(ctx)
return s return s
} }
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers. // cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) { func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL) ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
now := time.Now() if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
s.pkceVerifiers.Range(func(key, value any) bool { log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
} }
return true
})
} }
} }
} }
// Close stops background goroutines. func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
func (s *ProxyServiceServer) Close() { s.serviceManager = manager
s.pkceCleanupCancel()
} }
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) { func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
s.reverseProxyManager = manager s.proxyController = proxyController
} }
// GetMappingUpdate handles the control stream with proxy clients // GetMappingUpdate handles the control stream with proxy clients
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
ctx := stream.Context() ctx := stream.Context()
peerInfo := "" peerInfo := PeerIPFromContext(ctx)
if p, ok := peer.FromContext(ctx); ok {
peerInfo = p.Addr.String()
}
log.Infof("New proxy connection from %s", peerInfo) log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId() proxyID := req.GetProxyId()
@@ -177,13 +160,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID, proxyID: proxyID,
address: proxyAddress, address: proxyAddress,
stream: stream, stream: stream,
sendChan: make(chan *proto.ProxyMapping, 100), sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx, ctx: connCtx,
cancel: cancel, cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID) if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"address": proxyAddress, "address": proxyAddress,
@@ -191,8 +182,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID) s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID) if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel() cancel()
log.Infof("Proxy %s disconnected", proxyID) log.Infof("Proxy %s disconnected", proxyID)
}() }()
@@ -204,6 +202,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2) errChan := make(chan error, 2)
go s.sender(conn, errChan) go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select { select {
case err := <-errChan: case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -212,10 +213,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy. // sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent. // Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return fmt.Errorf("get services from store: %w", err) return fmt.Errorf("get services from store: %w", err)
} }
@@ -224,7 +242,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid") return fmt.Errorf("proxy address is invalid")
} }
var filtered []*reverseproxy.Service var filtered []*rpservice.Service
for _, service := range services { for _, service := range services {
if !service.Enabled { if !service.Enabled {
continue continue
@@ -259,7 +277,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{ Mapping: []*proto.ProxyMapping{
service.ToProtoMapping( service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token, token,
s.GetOIDCValidationConfig(), s.GetOIDCValidationConfig(),
), ),
@@ -288,7 +306,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
for { for {
select { select {
case msg := <-conn.sendChan: case msg := <-conn.sendChan:
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil { if err := conn.stream.Send(msg); err != nil {
errChan <- err errChan <- err
return return
} }
@@ -339,7 +357,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed. // Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per // For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management. // proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers") log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool { s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection) conn := value.(*proxyConnection)
@@ -349,7 +367,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID) log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default: default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
} }
@@ -393,81 +411,75 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls return urls
} }
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per // For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management. // proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
updateResponse := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{update},
}
if clusterAddr == "" { if clusterAddr == "" {
s.SendServiceUpdate(update) s.SendServiceUpdate(updateResponse)
return return
} }
proxySet, ok := s.clusterProxies.Load(clusterAddr) if s.proxyController == nil {
if !ok { log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
log.Debugf("No proxies connected for cluster %s", clusterAddr) return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return return
} }
log.Debugf("Sending service update to cluster %s", clusterAddr) log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { for _, proxyID := range proxyIDs {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok { if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection) conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID) msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil { if msg == nil {
return true continue
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default: default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
} }
} }
return true
})
} }
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original message is // create/update operations. For delete operations the original mapping is
// returned unchanged because proxies do not need to authenticate for removal. // used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped). // Returns nil if token generation fails (the proxy should be skipped).
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping { func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" { resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
return update for _, mapping := range update.Mapping {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
resp = append(resp, mapping)
continue
} }
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute) token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
if err != nil { if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil return nil
} }
msg := shallowCloneMapping(update) msg := shallowCloneMapping(mapping)
msg.AuthToken = token msg.AuthToken = token
return msg resp = append(resp, msg)
}
return &proto.GetMappingUpdateResponse{
Mapping: resp,
}
} }
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the // shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
@@ -486,35 +498,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
} }
} }
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err) log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -533,7 +518,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil }, nil
} }
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) { switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin: case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -544,7 +529,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
} }
} }
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -558,7 +543,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN return true, "pin-user", proxyauth.MethodPIN
} }
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -580,7 +565,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
} }
} }
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" { if !authenticated || service.SessionPrivateKey == "" {
return "", nil return "", nil
} }
@@ -620,7 +605,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
if certificateIssued { if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err) return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
} }
@@ -632,7 +617,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus) internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status") log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err) return nil, status.Errorf(codes.Internal, "update service status: %v", err)
} }
@@ -647,22 +632,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
// protoStatusToInternal maps proto status to internal status // protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus { func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus { switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING: case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE: case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED: case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR: case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError return rpservice.StatusError
default: default:
return reverseproxy.StatusError return rpservice.StatusError
} }
} }
@@ -727,7 +712,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
} }
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId()) services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err) log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -771,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier() codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}
return &proto.GetOIDCURLResponse{ return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{ Url: (&oauth2.Config{
@@ -790,8 +778,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation // GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping. // in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig { func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{ return proxy.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer, Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience}, Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation, KeysLocation: s.oidcConfig.KeysLocation,
@@ -808,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback. // ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid. // Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
v, ok := s.pkceVerifiers.LoadAndDelete(state) verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok { if !ok {
return "", "", errors.New("no verifier for state") return "", "", errors.New("no verifier for state")
} }
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|") parts := strings.Split(state, "|")
@@ -850,12 +830,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key // Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("get services: %w", err) return "", fmt.Errorf("get services: %w", err)
} }
var service *reverseproxy.Service var service *rpservice.Service
for _, svc := range services { for _, svc := range services {
if svc.Domain == domain { if svc.Domain == domain {
service = svc service = svc
@@ -921,8 +901,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
} }
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID) services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account services: %w", err) return nil, fmt.Errorf("get account services: %w", err)
} }
@@ -1043,8 +1023,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("get services: %w", err) return nil, fmt.Errorf("get services: %w", err)
} }
@@ -1058,7 +1038,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain) return nil, fmt.Errorf("service not found for domain: %s", domain)
} }
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error { func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil return nil
} }

View File

@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
} }
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) { func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx) clientIP := PeerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) { if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts") return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")

View File

@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
l.cancel() l.cancel()
} }
// peerIPFromContext extracts the client IP from the gRPC context. // PeerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address. // Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP { func PeerIPFromContext(ctx context.Context) string {
if addr, ok := realip.FromContext(ctx); ok { if addr, ok := realip.FromContext(ctx); ok {
return addr.String() return addr.String()
} }

View File

@@ -8,40 +8,44 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
type mockReverseProxyManager struct { type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
} }
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil { if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.proxiesByAccount[accountID], nil return m.proxiesByAccount[accountID], nil
} }
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*reverseproxy.Service{}, nil return []*service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error { func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -52,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil return nil
} }
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil return nil
} }
@@ -64,14 +68,28 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil return nil
} }
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) { func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &reverseproxy.Service{}, nil return &service.Service{}, nil
} }
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil return "", nil
} }
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &service.ExposeServiceResponse{}, nil
}
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
type mockUsersManager struct { type mockUsersManager struct {
users map[string]*types.User users map[string]*types.User
err error err error
@@ -93,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string name string
domain string domain string
userID string userID string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
users map[string]*types.User users map[string]*types.User
proxyErr error proxyErr error
userErr error userErr error
@@ -104,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found", name: "user not found",
domain: "app.example.com", domain: "app.example.com",
userID: "unknown-user", userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
users: map[string]*types.User{}, users: map[string]*types.User{},
@@ -115,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account", name: "proxy not found in user's account",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
}, },
@@ -126,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible", name: "proxy exists in different account - not accessible",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}}, "account2": {{Domain: "app.example.com", AccountID: "account2"}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
@@ -139,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access", name: "no bearer auth configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
}, },
users: map[string]*types.User{ users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"}, "user1": {Id: "user1", AccountID: "account1"},
@@ -151,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access", name: "bearer auth disabled - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, BearerAuth: &service.BearerAuthConfig{Enabled: false},
}, },
}}, }},
}, },
@@ -169,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access", name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{}, DistributionGroups: []string{},
}, },
@@ -190,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups", name: "user not in allowed groups",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -212,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access", name: "user in one of the allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -233,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access", name: "user in all allowed groups - allow access",
domain: "app.example.com", domain: "app.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{ "account1": {{
Domain: "app.example.com", Domain: "app.example.com",
AccountID: "account1", AccountID: "account1",
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"group1", "group2"}, DistributionGroups: []string{"group1", "group2"},
}, },
@@ -266,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one", name: "multiple proxies in account - finds correct one",
domain: "app2.example.com", domain: "app2.example.com",
userID: "user1", userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "app1.example.com", AccountID: "account1"}, {Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, {Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"}, {Domain: "app3.example.com", AccountID: "account1"},
}, },
}, },
@@ -283,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{ server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{ serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount, proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr, err: tt.proxyErr,
}, },
@@ -310,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string name string
accountID string accountID string
domain string domain string
proxiesByAccount map[string][]*reverseproxy.Service proxiesByAccount map[string][]*service.Service
err error err error
expectProxy bool expectProxy bool
expectErr bool expectErr bool
@@ -319,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found", name: "proxy found",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": { "account1": {
{Domain: "other.example.com", AccountID: "account1"}, {Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"}, {Domain: "app.example.com", AccountID: "account1"},
@@ -332,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account", name: "proxy not found in account",
accountID: "account1", accountID: "account1",
domain: "unknown.example.com", domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{ proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}}, "account1": {{Domain: "app.example.com", AccountID: "account1"}},
}, },
expectProxy: false, expectProxy: false,
@@ -342,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account", name: "empty proxy list for account",
accountID: "account1", accountID: "account1",
domain: "app.example.com", domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{}, proxiesByAccount: map[string][]*service.Service{},
expectProxy: false, expectProxy: false,
expectErr: true, expectErr: true,
}, },
@@ -360,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{ server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{ serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount, proxiesByAccount: tt.proxiesByAccount,
err: tt.err, err: tt.err,
}, },

View File

@@ -1,6 +1,7 @@
package grpc package grpc
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"strings" "strings"
@@ -11,13 +12,65 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
}
func newTestProxyController() *testProxyController {
return &testProxyController{
clusterProxies: make(map[string]map[string]struct{}),
}
}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
}
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return proxy.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clusterProxies[clusterAddr]; !ok {
c.clusterProxies[clusterAddr] = make(map[string]struct{})
}
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
delete(proxies, proxyID)
}
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
proxies, ok := c.clusterProxies[clusterAddr]
if !ok {
return nil
}
result := make([]string, 0, len(proxies))
for id := range proxies {
result = append(result, id)
}
return result
}
// registerFakeProxy adds a fake proxy connection to the server's internal maps // registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received. // and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping { func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.ProxyMapping, 10) ch := make(chan *proto.GetMappingUpdateResponse, 10)
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
address: clusterAddr, address: clusterAddr,
@@ -25,13 +78,12 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) _ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch return ch
} }
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
select { select {
case msg := <-ch: case msg := <-ch:
return msg return msg
@@ -41,24 +93,29 @@ func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
} }
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) ctx := context.Background()
defer tokenStore.Close() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100), pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com" const cluster = "proxy.example.com"
const numProxies = 3 const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies) channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
for i := range numProxies { for i := range numProxies {
id := "proxy-" + string(rune('a'+i)) id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster) channels[i] = registerFakeProxy(s, id, cluster)
} }
update := &proto.ProxyMapping{ mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1", Id: "service-1",
AccountId: "account-1", AccountId: "account-1",
@@ -68,14 +125,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}, },
} }
s.SendServiceUpdateToCluster(update, cluster) s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies) tokens := make([]string, numProxies)
for i, ch := range channels { for i, ch := range channels {
msg := drainChannel(ch) resp := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i) require.NotNil(t, resp, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain) require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
assert.Equal(t, update.Id, msg.Id) msg := resp.Mapping[0]
assert.Equal(t, mapping.Domain, msg.Domain)
assert.Equal(t, mapping.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken tokens[i] = msg.AuthToken
} }
@@ -96,66 +155,84 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
} }
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) ctx := context.Background()
defer tokenStore.Close() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100), pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com" const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster) ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster) ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{ mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1", Id: "service-1",
AccountId: "account-1", AccountId: "account-1",
Domain: "test.example.com", Domain: "test.example.com",
} }
s.SendServiceUpdateToCluster(update, cluster) s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
msg1 := drainChannel(ch1) resp1 := drainChannel(ch1)
msg2 := drainChannel(ch2) resp2 := drainChannel(ch2)
require.NotNil(t, msg1) require.NotNil(t, resp1)
require.NotNil(t, msg2) require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens // Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken) assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, msg2.AuthToken) assert.Empty(t, resp2.Mapping[0].AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
} }
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour) ctx := context.Background()
defer tokenStore.Close() tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100), pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController())
// Register proxies in different clusters (SendServiceUpdate broadcasts to all) // Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b") ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{ mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1", Id: "service-1",
AccountId: "account-1", AccountId: "account-1",
Domain: "test.example.com", Domain: "test.example.com",
} }
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdate(update) s.SendServiceUpdate(update)
msg1 := drainChannel(ch1) resp1 := drainChannel(ch1)
msg2 := drainChannel(ch2) resp2 := drainChannel(ch2)
require.NotNil(t, msg1) require.NotNil(t, resp1)
require.NotNil(t, msg2) require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := resp1.Mapping[0]
msg2 := resp2.Mapping[0]
assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken) assert.NotEmpty(t, msg2.AuthToken)
@@ -178,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
} }
func TestOAuthState_NeverTheSame(t *testing.T) { func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
redirectURL := "https://app.example.com/callback" redirectURL := "https://app.example.com/callback"
@@ -202,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
} }
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
// Old format had only 2 parts: base64(url)|hmac // Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("base64url|hmac") _, _, err = s.ValidateState("base64url|hmac")
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format") assert.Contains(t, err.Error(), "invalid state format")
} }
func TestValidateState_RejectsInvalidHMAC(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
// Store with tampered HMAC // Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature") assert.Contains(t, err.Error(), "invalid state signature")
} }

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/shared/management/client/common" "github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
@@ -80,6 +81,9 @@ type Server struct {
syncSem atomic.Int32 syncSem atomic.Int32
syncLimEnabled bool syncLimEnabled bool
syncLim int32 syncLim int32
reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex
} }
// NewServer creates a new Management server // NewServer creates a new Management server

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -34,11 +34,18 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err) require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore} serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
proxyService.SetProxyManager(proxyManager) require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)
@@ -54,7 +61,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t) pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{ testProxy := &service.Service{
ID: "testProxyId", ID: "testProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Test Proxy", Name: "Test Proxy",
@@ -62,15 +69,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
}, },
}, },
} }
require.NoError(t, testStore.CreateService(ctx, testProxy)) require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{ restrictedProxy := &service.Service{
ID: "restrictedProxyId", ID: "restrictedProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Restricted Proxy", Name: "Restricted Proxy",
@@ -78,8 +85,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true, Enabled: true,
SessionPrivateKey: privKey, SessionPrivateKey: privKey,
SessionPublicKey: pubKey, SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"allowedGroupId"}, DistributionGroups: []string{"allowedGroupId"},
}, },
@@ -196,7 +203,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied") assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason) assert.Equal(t, "service_not_found", resp.DeniedReason)
} }
func TestValidateSession_InvalidToken(t *testing.T) { func TestValidateSession_InvalidToken(t *testing.T) {
@@ -239,62 +246,102 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing") assert.Contains(t, resp.DeniedReason, "missing")
} }
type testValidateSessionProxyManager struct { type testValidateSessionServiceManager struct {
store store.Store store store.Store
} }
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error { func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil return nil
} }
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone) return m.store.GetServices(ctx, store.LockingStrengthNone)
} }
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
} }
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil return "", nil
} }
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct { type testValidateSessionUsersManager struct {
store store.Store store store.Store
} }

View File

@@ -15,7 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -85,7 +85,7 @@ type DefaultAccountManager struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
settingsManager settings.Manager settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
// config contains the management server configuration // config contains the management server configuration
config *nbconfig.Config config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil) var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.reverseProxyManager = serviceManager am.serviceManager = serviceManager
} }
func isUniqueConstraintError(err error) bool { func isUniqueConstraintError(err error) bool {
@@ -376,6 +376,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err return nil, err
} }
@@ -394,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
} }
if reloadReverseProxy { if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
} }
} }
@@ -492,6 +493,21 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
} }
} }
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
oldEnabled := oldSettings.PeerExposeEnabled
newEnabled := newSettings.PeerExposeEnabled
if oldEnabled == newEnabled {
return
}
event := activity.AccountPeerExposeEnabled
if !newEnabled {
event = activity.AccountPeerExposeDisabled
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled { if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -714,6 +730,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
} }
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users { for _, otherUser := range account.Users {
if otherUser.Id == userID { if otherUser.Id == userID {
continue continue
@@ -1358,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account. // We override incoming domain claims to group users under a single account.
userAuth.Domain = am.singleAccountModeDomain err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
userAuth.DomainCategory = types.PrivateCategory if err != nil {
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") return "", "", err
}
} }
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
@@ -1393,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil return accountID, user.Id, nil
} }
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
userAuth.DomainCategory = types.PrivateCategory
userAuth.Domain = am.singleAccountModeDomain
accountID, err := am.Store.GetAnyAccountID(ctx)
if err != nil {
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
return err
}
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
if accountID == "" {
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
userAuth.Domain = domain
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
return nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager

View File

@@ -1,12 +1,14 @@
package account package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import ( import (
"context" "context"
"net" "net"
"net/netip" "net/netip"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -61,11 +63,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -140,5 +142,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager) SetServiceManager(serviceManager service.Manager)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -15,10 +15,12 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push" "github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -27,10 +29,13 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
@@ -1802,12 +1807,12 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24", Address: "172.12.6.1/24",
}, },
}, },
Services: []*reverseproxy.Service{ Services: []*service.Service{
{ {
ID: "service1", ID: "service1",
Name: "test-service", Name: "test-service",
AccountID: "account1", AccountID: "account1",
Targets: []*reverseproxy.Target{}, Targets: []*service.Target{},
}, },
}, },
NetworkMapCache: &types.NetworkMapBuilder{}, NetworkMapCache: &types.NetworkMapBuilder{},
@@ -3112,6 +3117,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
CleanupStale(gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
ctx := context.Background() ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -3122,7 +3133,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err return nil, nil, err
} }
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil)) proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
return manager, updateManager, nil return manager, updateManager, nil
} }
@@ -3951,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange") t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
} }
} }
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("real-domain.com", "private", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "real-domain.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.NotFound, "no accounts"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.Internal, "db down"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "db down")
// Defaults should still be set before error path
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("", "", status.Errorf(status.Internal, "query failed"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "query failed")
})
}

View File

@@ -208,6 +208,25 @@ const (
ServiceUpdated Activity = 109 ServiceUpdated Activity = 109
ServiceDeleted Activity = 110 ServiceDeleted Activity = 110
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
PeerServiceExposed Activity = 111
// PeerServiceUnexposed indicates that a peer-exposed service was removed
PeerServiceUnexposed Activity = 112
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
PeerServiceExposeExpired Activity = 113
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
AccountPeerExposeEnabled Activity = 114
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
// DomainAdded indicates that a user added a custom domain
DomainAdded Activity = 116
// DomainDeleted indicates that a user deleted a custom domain
DomainDeleted Activity = 117
// DomainValidated indicates that a custom domain was validated
DomainValidated Activity = 118
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -345,6 +364,17 @@ var activityMap = map[Activity]Code{
ServiceCreated: {"Service created", "service.create"}, ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"}, ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"}, ServiceDeleted: {"Service deleted", "service.delete"},
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
switch storeEngine { switch storeEngine {
case types.SqliteStoreEngine: case types.SqliteStoreEngine:
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB)) dbFile := eventSinkDB
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
dbFile = envFile
}
connStr := dbFile
if !filepath.IsAbs(dbFile) {
connStr = filepath.Join(dataDir, dbFile)
}
dialector = sqlite.Open(connStr)
case types.PostgresStoreEngine: case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv) dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok { if !ok {

View File

@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string var groupIDsToDelete []string
var deletedGroups []*types.Group var deletedGroups []*types.Group
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
continue continue
} }
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
} }
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil return nil
} }
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"network resource", group.Resources[0].ID} return &GroupLinkError{"network resource", group.Resources[0].ID}
} }
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }

View File

@@ -12,6 +12,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -26,6 +27,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
} }
} }
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctrl := gomock.NewController(t)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
AnyTimes()
settingsMock.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
am.settingsManager = settingsMock
_, account, err := initTestGroupAccount(am)
require.NoError(t, err)
grp := &types.Group{
ID: "grp-for-flow",
AccountID: account.Id,
Name: "Group for flow",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
require.Error(t, err)
var gErr *GroupLinkError
require.ErrorAs(t, err, &gErr)
assert.Equal(t, "settings", gErr.Resource)
assert.Equal(t, "traffic event logging", gErr.Name)
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
regularGrp := &types.Group{
ID: "grp-regular",
AccountID: account.Id,
Name: "Regular group",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
require.NoError(t, err)
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
require.Error(t, err)
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
assert.Error(t, err)
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store) permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager) resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp" idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
) )
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router) idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router) instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil { if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
} }
// Register OAuth callback handler for proxy authentication // Register OAuth callback handler for proxy authentication

View File

@@ -168,6 +168,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
} }
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) { func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
}
returnSettings := &types.Settings{ returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
@@ -175,6 +179,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
PeerExposeGroups: req.Settings.PeerExposeGroups,
} }
if req.Settings.Extra != nil { if req.Settings.Extra != nil {
@@ -336,6 +343,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
JwtAllowGroups: &jwtAllowGroups, JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: settings.PeerExposeEnabled,
PeerExposeGroups: settings.PeerExposeGroups,
LazyConnectionEnabled: &settings.LazyConnectionEnabled, LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain, DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion, AutoUpdateVersion: &settings.AutoUpdateVersion,

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer() oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)
@@ -205,12 +209,14 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService := nbgrpc.NewProxyServiceServer( proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{}, &testAccessLogManager{},
tokenStore, tokenStore,
pkceStore,
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
nil,
) )
proxyService.SetProxyManager(&testServiceManager{store: testStore}) proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil) handler := NewAuthCallbackHandler(proxyService, nil)
@@ -239,12 +245,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub) pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv) privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{ testProxy := &service.Service{
ID: "testProxyId", ID: "testProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Test Proxy", Name: "Test Proxy",
Domain: "test-proxy.example.com", Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -254,8 +260,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"allowedGroupId"}, DistributionGroups: []string{"allowedGroupId"},
}, },
@@ -265,12 +271,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
} }
require.NoError(t, testStore.CreateService(ctx, testProxy)) require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{ restrictedProxy := &service.Service{
ID: "restrictedProxyId", ID: "restrictedProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "Restricted Proxy", Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com", Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -280,8 +286,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: true, Enabled: true,
DistributionGroups: []string{"restrictedGroupId"}, DistributionGroups: []string{"restrictedGroupId"},
}, },
@@ -291,12 +297,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
} }
require.NoError(t, testStore.CreateService(ctx, restrictedProxy)) require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{ noAuthProxy := &service.Service{
ID: "noAuthProxyId", ID: "noAuthProxyId",
AccountID: "testAccountId", AccountID: "testAccountId",
Name: "No Auth Proxy", Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com", Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{ Targets: []*service.Target{{
Path: strPtr("/"), Path: strPtr("/"),
Host: "localhost", Host: "localhost",
Port: 8080, Port: 8080,
@@ -306,8 +312,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true, Enabled: true,
}}, }},
Enabled: true, Enabled: true,
Auth: reverseproxy.AuthConfig{ Auth: service.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{ BearerAuth: &service.BearerAuthConfig{
Enabled: false, Enabled: false,
}, },
}, },
@@ -357,19 +363,23 @@ type testServiceManager struct {
store store.Store store store.Store
} }
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil return nil, nil
} }
@@ -381,7 +391,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
return nil return nil
} }
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil return nil
} }
@@ -393,15 +403,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
return nil return nil
} }
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone) return m.store.GetServices(ctx, store.LockingStrengthNone)
} }
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
} }
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
} }
@@ -409,6 +419,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil return "", nil
} }
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper() t.Helper()

View File

@@ -9,10 +9,13 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +94,28 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
} }
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) if err != nil {
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) t.Fatalf("Failed to create proxy token store: %v", err)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) }
proxyServiceServer.SetProxyManager(reverseProxyManager) pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
am.SetServiceManager(reverseProxyManager) if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked // @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }

View File

@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
// EmbeddedStorageConfig holds storage configuration for the embedded IdP. // EmbeddedStorageConfig holds storage configuration for the embedded IdP.
type EmbeddedStorageConfig struct { type EmbeddedStorageConfig struct {
// Type is the storage type (currently only "sqlite3" is supported) // Type is the storage type: "sqlite3" (default) or "postgres"
Type string Type string
// Config contains type-specific configuration // Config contains type-specific configuration
Config EmbeddedStorageTypeConfig Config EmbeddedStorageTypeConfig
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
type EmbeddedStorageTypeConfig struct { type EmbeddedStorageTypeConfig struct {
// File is the path to the SQLite database file (for sqlite3 type) // File is the path to the SQLite database file (for sqlite3 type)
File string File string
// DSN is the connection string for postgres
DSN string
} }
// OwnerConfig represents the initial owner/admin user for the embedded IdP. // OwnerConfig represents the initial owner/admin user for the embedded IdP.
@@ -74,6 +76,22 @@ type OwnerConfig struct {
Username string Username string
} }
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
switch storageType {
case "sqlite3":
return map[string]interface{}{
"file": cfg.File,
}, nil
case "postgres":
return map[string]interface{}{
"dsn": cfg.DSN,
}, nil
default:
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
}
}
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig. // ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Issuer == "" { if c.Issuer == "" {
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" { if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
return nil, fmt.Errorf("storage file is required for sqlite3") return nil, fmt.Errorf("storage file is required for sqlite3")
} }
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
return nil, fmt.Errorf("storage DSN is required for postgres")
}
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
if err != nil {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback (both relative and absolute) // Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs cliRedirectURIs := c.CLIRedirectURIs
@@ -101,9 +127,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
Issuer: c.Issuer, Issuer: c.Issuer,
Storage: dex.Storage{ Storage: dex.Storage{
Type: c.Storage.Type, Type: c.Storage.Type,
Config: map[string]interface{}{ Config: storageConfig,
"file": c.Storage.Config.File,
},
}, },
Web: dex.Web{ Web: dex.Web{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},

View File

@@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -51,6 +52,7 @@ type properties map[string]interface{}
type DataSource interface { type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
} }
// ConnManager peer connection manager that holds state for current active connections // ConnManager peer connection manager that holds state for current active connections
@@ -210,6 +212,17 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rosenpassEnabled int rosenpassEnabled int
localUsers int localUsers int
idpUsers int idpUsers int
embeddedIdpTypes map[string]int
services int
servicesEnabled int
servicesTargets int
servicesStatusActive int
servicesStatusPending int
servicesStatusError int
servicesTargetType map[string]int
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
) )
start := time.Now() start := time.Now()
metricsProperties := make(properties) metricsProperties := make(properties)
@@ -218,10 +231,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rulesProtocol = make(map[string]int) rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int) rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{}) activeUsersLastDay = make(map[string]struct{})
embeddedIdpTypes = make(map[string]int)
servicesTargetType = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds() uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers() connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion() version = nbversion.NetbirdVersion()
customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx)
for _, account := range w.dataSource.GetAllAccounts(ctx) { for _, account := range w.dataSource.GetAllAccounts(ctx) {
accounts++ accounts++
@@ -278,6 +295,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} else { } else {
idpUsers++ idpUsers++
} }
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
} }
} }
} }
@@ -331,6 +350,37 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion) peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
} }
} }
for _, service := range account.Services {
services++
if service.Enabled {
servicesEnabled++
}
servicesTargets += len(service.Targets)
switch rpservice.Status(service.Meta.Status) {
case rpservice.StatusActive:
servicesStatusActive++
case rpservice.StatusPending:
servicesStatusPending++
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
servicesStatusError++
}
for _, target := range service.Targets {
servicesTargetType[target.TargetType]++
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled {
servicesAuthPassword++
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled {
servicesAuthPin++
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++
}
}
} }
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions) minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
@@ -369,6 +419,27 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rosenpass_enabled"] = rosenpassEnabled metricsProperties["rosenpass_enabled"] = rosenpassEnabled
metricsProperties["local_users_count"] = localUsers metricsProperties["local_users_count"] = localUsers
metricsProperties["idp_users_count"] = idpUsers metricsProperties["idp_users_count"] = idpUsers
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
metricsProperties["services"] = services
metricsProperties["services_enabled"] = servicesEnabled
metricsProperties["services_targets"] = servicesTargets
metricsProperties["services_status_active"] = servicesStatusActive
metricsProperties["services_status_pending"] = servicesStatusPending
metricsProperties["services_status_error"] = servicesStatusError
metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated
for targetType, count := range servicesTargetType {
metricsProperties["services_target_type_"+targetType] = count
}
for idpType, count := range embeddedIdpTypes {
metricsProperties["embedded_idp_users_"+idpType] = count
}
for protocol, count := range rulesProtocol { for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count metricsProperties["rules_protocol_"+protocol] = count
@@ -456,6 +527,20 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
return req, cancel, nil return req, cancel, nil
} }
// extractIdpType extracts the IdP type from a Dex connector ID.
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string {
if connectorID == "local" {
return "local"
}
idx := strings.LastIndex(connectorID, "-")
if idx <= 0 {
return "oidc"
}
return strings.ToLower(connectorID[:idx])
}
func getMinMaxVersion(inputList []string) (string, string) { func getMinMaxVersion(inputList []string) (string, string) {
versions := make([]*version.Version, 0) versions := make([]*version.Version, 0)

View File

@@ -6,6 +6,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -27,7 +28,8 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information // GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local") localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel") idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
return []*types.Account{ return []*types.Account{
{ {
Id: "1", Id: "1",
@@ -115,6 +117,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
}, },
}, },
Services: []*rpservice.Service{
{
ID: "svc1",
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: "peer"},
{TargetType: "host"},
},
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
},
{
ID: "svc2",
Enabled: false,
Targets: []*rpservice.Target{
{TargetType: "domain"},
},
Auth: rpservice.AuthConfig{
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
},
},
}, },
{ {
Id: "2", Id: "2",
@@ -180,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {}, "1": {},
}, },
}, },
oidcUserID: {
Id: oidcUserID,
IsServiceUser: false,
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
}, },
Networks: []*networkTypes.Network{ Networks: []*networkTypes.Network{
{ {
@@ -215,6 +249,11 @@ func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine return types.FileStoreEngine
} }
// GetCustomDomainsCounts returns test custom domain counts.
func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 3, 2, nil
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) { func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{} ds := mockDatasource{}
@@ -247,14 +286,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["rules"] != 4 { if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"]) t.Errorf("expected 4 rules, got %d", properties["rules"])
} }
if properties["users"] != 2 { if properties["users"] != 3 {
t.Errorf("expected 1 users, got %d", properties["users"]) t.Errorf("expected 3 users, got %d", properties["users"])
} }
if properties["setup_keys_usage"] != 2 { if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"]) t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
} }
if properties["pats"] != 4 { if properties["pats"] != 5 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"]) t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
} }
if properties["peers_ssh_enabled"] != 2 { if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"]) t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
@@ -338,7 +377,90 @@ func TestGenerateProperties(t *testing.T) {
if properties["local_users_count"] != 1 { if properties["local_users_count"] != 1 {
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"]) t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
} }
if properties["idp_users_count"] != 1 { if properties["idp_users_count"] != 2 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"]) t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_local"] != 1 {
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
}
if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
}
if properties["embedded_idp_users_oidc"] != 1 {
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
}
if properties["embedded_idp_count"] != 3 {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
if properties["services"] != 2 {
t.Errorf("expected 2 services, got %v", properties["services"])
}
if properties["services_enabled"] != 1 {
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
}
if properties["services_targets"] != 3 {
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
}
if properties["services_status_active"] != 1 {
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
}
if properties["services_status_pending"] != 1 {
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
}
if properties["services_status_error"] != 0 {
t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"])
}
if properties["services_target_type_peer"] != 1 {
t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"])
}
if properties["services_target_type_host"] != 1 {
t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"])
}
if properties["services_target_type_domain"] != 1 {
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
}
if properties["services_auth_password"] != 1 {
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
}
if properties["services_auth_oidc"] != 1 {
t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"])
}
if properties["services_auth_pin"] != 0 {
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
}
if properties["custom_domains"] != int64(3) {
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
}
if properties["custom_domains_validated"] != int64(2) {
t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"])
}
}
func TestExtractIdpType(t *testing.T) {
tests := []struct {
connectorID string
expected string
}{
{"okta-abc123def", "okta"},
{"zitadel-d5uv82dra0haedlf6kv0", "zitadel"},
{"entra-xyz789", "entra"},
{"google-abc123", "google"},
{"pocketid-abc123", "pocketid"},
{"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "local"},
{"d6jvvp69kmnc73c9pl40", "oidc"},
{"", "oidc"},
}
for _, tt := range tests {
t.Run(tt.connectorID, func(t *testing.T) {
result := extractIdpType(tt.connectorID)
if result != tt.expected {
t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected)
}
})
} }
} }

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
} }
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op // Mock implementation - no-op
} }
@@ -407,7 +407,7 @@ func (am *MockAccountManager) AddPeer(
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
if am.GetGroupFunc != nil { if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName) return am.GetGroupByNameFunc(ctx, accountID, groupName)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -37,19 +37,19 @@ type managerImpl struct {
permissionsManager permissions.Manager permissionsManager permissions.Manager
groupsManager groups.Manager groupsManager groups.Manager
accountManager account.Manager accountManager account.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
} }
type mockManager struct { type mockManager struct {
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
groupsManager: groupsManager, groupsManager: groupsManager,
accountManager: accountManager, accountManager: accountManager,
reverseProxyManager: reverseproxyManager, serviceManager: reverseproxyManager,
} }
} }
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them // TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() { go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID) err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err) log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
} }
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID) serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err) return fmt.Errorf("failed to check if resource is used by service: %w", err)
} }

Some files were not shown because too many files have changed in this diff Show More