Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay

# Conflicts:
#	client/android/client.go
#	client/ssh/server/server.go
#	shared/management/proto/management.pb.go
This commit is contained in:
Viktor Liu
2026-04-07 18:35:13 +02:00
137 changed files with 16950 additions and 1827 deletions

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -154,6 +154,26 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -166,6 +186,10 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -206,7 +206,7 @@ func (c *Client) PeersList() *PeerInfoArray {
IP: p.IP,
IPv6: p.IPv6,
FQDN: p.FQDN,
ConnStatus: p.ConnStatus.String(),
ConnStatus: int(p.ConnStatus),
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi

View File

@@ -2,12 +2,21 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
IPv6 string
FQDN string
ConnStatus string // Todo replace to enum
ConnStatus int
Routes PeerRoutes
}

View File

@@ -14,7 +14,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
@@ -200,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %w", err)
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -211,26 +213,31 @@ func exposeFn(cmd *cobra.Command, args []string) error {
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
case expose.ProtocolHTTPS:
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case "tcp":
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case "udp":
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case "tls":
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
return 0, fmt.Errorf("unhandled protocol type: %d", p)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)

View File

@@ -41,7 +41,7 @@ func init() {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")

View File

@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed")
return nil
},
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning {
cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil {

View File

@@ -0,0 +1,201 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -0,0 +1,523 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -13,6 +15,22 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -33,14 +33,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -378,6 +378,32 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()

45
client/embed/expose.go Normal file
View File

@@ -0,0 +1,45 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
protoFlow, err := client.GetPKCEAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
protoFlow, err := client.GetDeviceAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err

View File

@@ -47,6 +47,10 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
@@ -79,6 +83,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -625,12 +632,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -650,12 +651,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -0,0 +1,73 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -0,0 +1,32 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -504,7 +504,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetClientRoutes() {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
@@ -1914,6 +1914,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics

View File

@@ -829,7 +829,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1036,7 +1036,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1539,13 +1539,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1567,7 +1562,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,

View File

@@ -4,11 +4,14 @@ import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
const renewTimeout = 10 * time.Second
const (
renewTimeout = 10 * time.Second
)
// Response holds the response from exposing a service.
type Response struct {
@@ -18,11 +21,13 @@ type Response struct {
PortAutoAssigned bool
}
// Request holds the parameters for exposing a local service via the management server.
// It is part of the embed API surface and exposed via a type alias.
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Protocol ProtocolType
Pin string
Password string
UserGroups []string
@@ -59,6 +64,8 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
return fromClientExposeResponse(resp), nil
}
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
// It is part of the embed API surface and exposed via a type alias.
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

View File

@@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) {
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")

View File

@@ -0,0 +1,40 @@
package expose
import (
"fmt"
"strings"
)
// ProtocolType represents the protocol used for exposing a service.
type ProtocolType int
const (
// ProtocolHTTP exposes the service as HTTP.
ProtocolHTTP ProtocolType = 0
// ProtocolHTTPS exposes the service as HTTPS.
ProtocolHTTPS ProtocolType = 1
// ProtocolTCP exposes the service as TCP.
ProtocolTCP ProtocolType = 2
// ProtocolUDP exposes the service as UDP.
ProtocolUDP ProtocolType = 3
// ProtocolTLS exposes the service as TLS.
ProtocolTLS ProtocolType = 4
)
// ParseProtocolType parses a protocol string into a ProtocolType.
func ParseProtocolType(s string) (ProtocolType, error) {
switch strings.ToLower(s) {
case "http":
return ProtocolHTTP, nil
case "https":
return ProtocolHTTPS, nil
case "tcp":
return ProtocolTCP, nil
case "udp":
return ProtocolUDP, nil
case "tls":
return ProtocolTLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
}
}

View File

@@ -9,7 +9,7 @@ import (
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Protocol: ProtocolType(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
@@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,

View File

@@ -39,6 +39,18 @@ const (
DefaultAdminURL = "https://app.netbird.io:443"
)
// mgmProber is the subset of management client needed for URL migration probes.
type mgmProber interface {
HealthCheck() error
Close() error
}
// newMgmProber creates a management client for probing URL reachability.
// Overridden in tests to avoid real network calls.
var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) {
return mgm.NewClient(ctx, addr, key, tlsEnabled)
}
var DefaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
@@ -765,21 +777,19 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, err
}
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
defer func() {
err = client.Close()
if err != nil {
if err := client.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
if err = client.HealthCheck(); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -10,12 +10,21 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util"
)
type mockMgmProber struct{}
func (m *mockMgmProber) HealthCheck() error {
return nil
}
func (m *mockMgmProber) Close() error { return nil }
func TestGetConfig(t *testing.T) {
// case 1: new default config has to be generated
config, err := UpdateOrCreateConfig(ConfigInput{
@@ -234,6 +243,12 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
}
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
return &mockMgmProber{}, nil
}
t.Cleanup(func() { newMgmProber = origProber })
tests := []struct {
name string
previousManagementURL string
@@ -273,18 +288,17 @@ func TestUpdateOldManagementURL(t *testing.T) {
ConfigPath: configPath,
})
require.NoError(t, err, "failed to create testing config")
previousStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
previousContent, err := os.ReadFile(configPath)
require.NoError(t, err, "failed to read initial config")
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
require.NoError(t, err, "got error when updating old management url")
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
newStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
switch tt.fileShouldNotChange {
case true:
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
case false:
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
newContent, err := os.ReadFile(configPath)
require.NoError(t, err, "failed to read updated config")
if tt.fileShouldNotChange {
require.Equal(t, string(previousContent), string(newContent), "file should not change")
} else {
require.NotEqual(t, string(previousContent), string(newContent), "file should have changed")
}
})
}

View File

@@ -52,6 +52,7 @@ type Manager interface {
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
return maps.Clone(m.clientRoutes)
}
// GetSelectedClientRoutes returns only the currently selected/active client routes,
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
// if traffic should be routed through the tunnel.
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
m.mux.Lock()
defer m.mux.Unlock()
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()

View File

@@ -18,6 +18,7 @@ type MockManager struct {
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil
}
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
return nil
}
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
if m.GetSelectedClientRoutesFunc != nil {
return m.GetSelectedClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -31,26 +31,11 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
// initialRoutes contains fake IP block for interface configuration
filteredInitial := make([]*route.Route, 0)
for _, r := range initialRoutes {
if r.IsDynamic() {
continue
}
filteredInitial = append(filteredInitial, r)
}
n.initialRoutes = filteredInitial
// routesForComparison excludes fake IP block for comparison with new routes
filteredComparison := make([]*route.Route, 0)
for _, r := range routesForComparison {
if r.IsDynamic() {
continue
}
filteredComparison = append(filteredComparison, r)
}
n.currentRoutes = filteredComparison
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
@@ -83,13 +68,43 @@ func (n *Notifier) notify() {
return
}
routeStrings := n.routesToStrings(n.currentRoutes)
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ","))
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {
if !r.IsDynamic() {
out = append(out, r)
}
}
return out
}
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
nets := make([]string, 0, len(routes))
for _, r := range routes {

View File

@@ -1360,6 +1360,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
if engine.IsBlockInbound() {
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")

View File

@@ -314,13 +314,17 @@ func (s *Server) closeListener(ln net.Listener) {
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
sshServer := s.sshServer
if sshServer == nil {
s.mu.Unlock()
return nil
}
s.sshServer = nil
s.listener = nil
s.mu.Unlock()
if err := s.sshServer.Close(); err != nil {
// Close outside the lock: session handlers need s.mu for unregisterSession.
if err := sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
@@ -334,6 +338,7 @@ func (s *Server) Stop() error {
s.sshServer = nil
s.listener = nil
s.mu.Lock()
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
@@ -344,6 +349,7 @@ func (s *Server) Stop() error {
}
}
maps.Clear(s.remoteForwardListeners)
s.mu.Unlock()
return nil
}

View File

@@ -155,6 +155,9 @@ func networkAddresses() ([]NetworkAddress, error) {
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}

View File

@@ -326,6 +326,7 @@ type serviceClient struct {
exitNodeMu sync.Mutex
mExitNodeItems []menuHandler
exitNodeRetryCancel context.CancelFunc
mExitNodeSeparator *systray.MenuItem
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window

View File

@@ -421,6 +421,10 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
node.Remove()
}
s.mExitNodeItems = nil
if s.mExitNodeSeparator != nil {
s.mExitNodeSeparator.Remove()
s.mExitNodeSeparator = nil
}
if s.mExitNodeDeselectAll != nil {
s.mExitNodeDeselectAll.Remove()
s.mExitNodeDeselectAll = nil
@@ -453,31 +457,37 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
}
if showDeselectAll {
s.mExitNode.AddSeparator()
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
s.mExitNodeDeselectAll = deselectAllItem
go func() {
for {
_, ok := <-deselectAllItem.ClickedCh
if !ok {
// channel closed: exit the goroutine
return
}
exitNodes, err := s.handleExitNodeMenuDeselectAll()
if err != nil {
log.Warnf("failed to handle deselect all exit nodes: %v", err)
} else {
s.exitNodeMu.Lock()
s.recreateExitNodeMenu(exitNodes)
s.exitNodeMu.Unlock()
}
}
}()
s.addExitNodeDeselectAll()
}
}
func (s *serviceClient) addExitNodeDeselectAll() {
sep := s.mExitNode.AddSubMenuItem("───────────────", "")
sep.Disable()
s.mExitNodeSeparator = sep
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
s.mExitNodeDeselectAll = deselectAllItem
go func() {
for {
_, ok := <-deselectAllItem.ClickedCh
if !ok {
return
}
exitNodes, err := s.handleExitNodeMenuDeselectAll()
if err != nil {
log.Warnf("failed to handle deselect all exit nodes: %v", err)
} else {
s.exitNodeMu.Lock()
s.recreateExitNodeMenu(exitNodes)
s.exitNodeMu.Unlock()
}
}
}()
}
func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) {
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()

43
go.mod
View File

@@ -17,13 +17,13 @@ require (
github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.46.0
golang.org/x/sys v0.39.0
golang.org/x/crypto v0.48.0
golang.org/x/sys v0.41.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.77.0
google.golang.org/protobuf v1.36.10
google.golang.org/grpc v1.79.3
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -49,6 +49,7 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang/mock v1.6.0
@@ -101,21 +102,21 @@ require (
github.com/vmihailenco/msgpack/v5 v5.4.1
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0
go.opentelemetry.io/otel v1.42.0
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
go.opentelemetry.io/otel/metric v1.42.0
go.opentelemetry.io/otel/sdk/metric v1.42.0
go.uber.org/mock v0.5.2
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.30.0
golang.org/x/net v0.47.0
golang.org/x/mod v0.32.0
golang.org/x/net v0.51.0
golang.org/x/oauth2 v0.34.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.38.0
golang.org/x/term v0.40.0
golang.org/x/time v0.14.0
google.golang.org/api v0.257.0
gopkg.in/yaml.v3 v3.0.1
@@ -181,7 +182,6 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -249,8 +249,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/russellhaering/goxmldsig v1.5.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
@@ -269,15 +270,15 @@ require (
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)

86
go.sum
View File

@@ -487,10 +487,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
@@ -603,26 +605,26 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs=
go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@@ -633,8 +635,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -648,8 +650,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -666,8 +668,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -686,8 +688,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
@@ -738,8 +740,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -752,8 +754,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -765,8 +767,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -780,8 +782,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -799,12 +801,12 @@ google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4=
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
@@ -815,8 +817,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -170,20 +170,66 @@ type Connector struct {
}
// ToStorageConnector converts a Connector to storage.Connector type.
// It maps custom connector types (e.g., "zitadel", "entra") to Dex-native types
// and augments the config with OIDC defaults when needed.
func (c *Connector) ToStorageConnector() (storage.Connector, error) {
data, err := json.Marshal(c.Config)
dexType, augmentedConfig := mapConnectorToDex(c.Type, c.Config)
data, err := json.Marshal(augmentedConfig)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err)
}
return storage.Connector{
ID: c.ID,
Type: c.Type,
Type: dexType,
Name: c.Name,
Config: data,
}, nil
}
// mapConnectorToDex maps custom connector types to Dex-native types and applies
// OIDC defaults. This ensures static connectors from config files or env vars
// are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
return "oidc", applyOIDCDefaults(connType, config)
default:
return connType, config
}
}
// applyOIDCDefaults clones the config map, sets common OIDC defaults,
// and applies provider-specific overrides.
func applyOIDCDefaults(connType string, config map[string]interface{}) map[string]interface{} {
augmented := make(map[string]interface{}, len(config)+4)
for k, v := range config {
augmented[k] = v
}
setDefault(augmented, "scopes", []string{"openid", "profile", "email"})
setDefault(augmented, "insecureEnableGroups", true)
setDefault(augmented, "insecureSkipEmailVerified", true)
switch connType {
case "zitadel":
setDefault(augmented, "getUserInfo", true)
case "entra":
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return augmented
}
// setDefault sets a key in the map only if it doesn't already exist.
func setDefault(m map[string]interface{}, key string, value interface{}) {
if _, ok := m[key]; !ok {
m[key] = value
}
}
// StorageConfig is a configuration that can create a storage.
type StorageConfig interface {
Open(logger *slog.Logger) (storage.Storage, error)

View File

@@ -4,6 +4,7 @@ package dex
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -19,10 +20,13 @@ import (
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Config matches what management/internals/server/server.go expects
@@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string {
}
return issuer + "/auth"
}
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
keys, err := p.storage.GetKeys(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
}
if keys.SigningKeyPub == nil {
return nil, fmt.Errorf("no public keys found in storage")
}
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
// signing key first, then all verification (rotated) keys.
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
joseKeys = append(joseKeys, *keys.SigningKeyPub)
for _, vk := range keys.VerificationKeys {
if vk.PublicKey != nil {
joseKeys = append(joseKeys, *vk.PublicKey)
}
}
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
// then deserialize into our Jwks type, so the JSON field mapping is identical
// to what the /keys HTTP endpoint would return.
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
data, err := json.Marshal(joseSet)
if err != nil {
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
}
jwks := &nbjwt.Jwks{}
if err := json.Unmarshal(data, jwks); err != nil {
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
}
jwks.ExpiresInTime = keys.NextRotation
return jwks, nil
}

View File

@@ -2,11 +2,14 @@ package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
sqllib "github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -197,6 +200,295 @@ enablePasswordDB: true
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
}
// openTestStorage creates a SQLite storage in the given directory for testing.
func openTestStorage(t *testing.T, tmpDir string) storage.Storage {
t.Helper()
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
stor, err := (&sqllib.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return stor
}
func TestStaticConnectors_CreatedFromYAML(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: My OIDC Provider
config:
issuer: https://accounts.example.com
clientID: test-client-id
clientSecret: test-client-secret
redirectURI: http://localhost:5556/dex/callback
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
// Open storage and run initializeStorage directly (avoids Dex server
// trying to dial the OIDC issuer)
stor := openTestStorage(t, tmpDir)
defer stor.Close()
err = initializeStorage(ctx, stor, yamlConfig)
require.NoError(t, err)
// Verify connector was created in storage
conn, err := stor.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "my-oidc", conn.ID)
assert.Equal(t, "My OIDC Provider", conn.Name)
assert.Equal(t, "oidc", conn.Type)
// Verify config fields were serialized correctly
var configMap map[string]interface{}
err = json.Unmarshal(conn.Config, &configMap)
require.NoError(t, err)
assert.Equal(t, "https://accounts.example.com", configMap["issuer"])
assert.Equal(t, "test-client-id", configMap["clientID"])
}
func TestStaticConnectors_UpdatedOnRestart(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-update-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First: load config with initial connector
yamlContent1 := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + dbFile + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: Original Name
config:
issuer: https://accounts.example.com
clientID: original-client-id
clientSecret: original-secret
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent1), 0644)
require.NoError(t, err)
yamlConfig1, err := LoadConfig(configPath)
require.NoError(t, err)
stor := openTestStorage(t, tmpDir)
err = initializeStorage(ctx, stor, yamlConfig1)
require.NoError(t, err)
// Verify initial state
conn, err := stor.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "Original Name", conn.Name)
var configMap1 map[string]interface{}
err = json.Unmarshal(conn.Config, &configMap1)
require.NoError(t, err)
assert.Equal(t, "original-client-id", configMap1["clientID"])
// Close storage to simulate restart
stor.Close()
// Second: load updated config against the same DB
yamlContent2 := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + dbFile + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: Updated Name
config:
issuer: https://accounts.example.com
clientID: updated-client-id
clientSecret: updated-secret
`
err = os.WriteFile(configPath, []byte(yamlContent2), 0644)
require.NoError(t, err)
yamlConfig2, err := LoadConfig(configPath)
require.NoError(t, err)
stor2 := openTestStorage(t, tmpDir)
defer stor2.Close()
err = initializeStorage(ctx, stor2, yamlConfig2)
require.NoError(t, err)
// Verify connector was updated, not duplicated
allConnectors, err := stor2.ListConnectors(ctx)
require.NoError(t, err)
nonLocalCount := 0
for _, c := range allConnectors {
if c.ID != "local" {
nonLocalCount++
}
}
assert.Equal(t, 1, nonLocalCount, "connector should be updated, not duplicated")
conn2, err := stor2.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "Updated Name", conn2.Name)
var configMap2 map[string]interface{}
err = json.Unmarshal(conn2.Config, &configMap2)
require.NoError(t, err)
assert.Equal(t, "updated-client-id", configMap2["clientID"])
}
func TestStaticConnectors_MultipleConnectors(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-multi-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: My OIDC Provider
config:
issuer: https://accounts.example.com
clientID: oidc-client-id
clientSecret: oidc-secret
- type: google
id: my-google
name: Google Login
config:
clientID: google-client-id
clientSecret: google-secret
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
stor := openTestStorage(t, tmpDir)
defer stor.Close()
err = initializeStorage(ctx, stor, yamlConfig)
require.NoError(t, err)
allConnectors, err := stor.ListConnectors(ctx)
require.NoError(t, err)
// Build a map for easier assertion
connByID := make(map[string]storage.Connector)
for _, c := range allConnectors {
connByID[c.ID] = c
}
// Verify both static connectors exist
oidcConn, ok := connByID["my-oidc"]
require.True(t, ok, "oidc connector should exist")
assert.Equal(t, "My OIDC Provider", oidcConn.Name)
assert.Equal(t, "oidc", oidcConn.Type)
var oidcConfig map[string]interface{}
err = json.Unmarshal(oidcConn.Config, &oidcConfig)
require.NoError(t, err)
assert.Equal(t, "oidc-client-id", oidcConfig["clientID"])
googleConn, ok := connByID["my-google"]
require.True(t, ok, "google connector should exist")
assert.Equal(t, "Google Login", googleConn.Name)
assert.Equal(t, "google", googleConn.Type)
var googleConfig map[string]interface{}
err = json.Unmarshal(googleConn.Config, &googleConfig)
require.NoError(t, err)
assert.Equal(t, "google-client-id", googleConfig["clientID"])
// Verify local connector still exists alongside them (enablePasswordDB: true)
localConn, ok := connByID["local"]
require.True(t, ok, "local connector should exist")
assert.Equal(t, "local", localConn.Type)
}
func TestStaticConnectors_EmptyList(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-empty-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
provider, err := NewProviderFromYAML(ctx, yamlConfig)
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
// No static connectors configured, so ListConnectors should return empty
connectors, err := provider.ListConnectors(ctx)
require.NoError(t, err)
assert.Empty(t, connectors)
// But local connector should still exist
localConn, err := provider.Storage().GetConnector(ctx, "local")
require.NoError(t, err)
assert.Equal(t, "local", localConn.ID)
}
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
ctx := context.Background()

View File

@@ -172,8 +172,11 @@ init_environment() {
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo ""
echo "Login with the following credentials:"
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
install -m 600 /dev/null .env
printf 'Email: admin@%s\nPassword: %s\n' \
"$NETBIRD_DOMAIN" "$NETBIRD_ADMIN_PASSWORD" >> .env
echo "Email: admin@$NETBIRD_DOMAIN"
echo "Password: $NETBIRD_ADMIN_PASSWORD"
echo ""
echo "Dex admin UI is not available (Dex has no built-in UI)."
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"

View File

@@ -563,8 +563,11 @@ initEnvironment() {
echo -e "\nDone!\n"
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo "Login with the following credentials:"
echo "Username: $ZITADEL_ADMIN_USERNAME" | tee .env
echo "Password: $ZITADEL_ADMIN_PASSWORD" | tee -a .env
install -m 600 /dev/null .env
printf 'Username: %s\nPassword: %s\n' \
"$ZITADEL_ADMIN_USERNAME" "$ZITADEL_ADMIN_PASSWORD" >> .env
echo "Username: $ZITADEL_ADMIN_USERNAME"
echo "Password: $ZITADEL_ADMIN_PASSWORD"
}
renderCaddyfile() {

View File

@@ -1154,7 +1154,16 @@ print_builtin_traefik_instructions() {
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
echo ""
fi
echo ""
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
echo ""
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo "NetBird Proxy:"
echo " The proxy service is enabled and running."
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."

View File

@@ -154,9 +154,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
}
return nil
})

View File

@@ -31,19 +31,15 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type clusterCapabilities interface {
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
store store
validator domain.Validator
proxyManager proxyManager
clusterCapabilities clusterCapabilities
permissionsManager permissions.Manager
store store
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
}
@@ -57,11 +53,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
}
}
// SetClusterCapabilities sets the cluster capabilities provider for domain queries.
func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) {
m.clusterCapabilities = caps
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
@@ -97,10 +88,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
Type: domain.TypeFree,
Validated: true,
}
if m.clusterCapabilities != nil {
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
}
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
ret = append(ret, d)
}
@@ -114,8 +103,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
Type: domain.TypeCustom,
Validated: d.Validated,
}
if m.clusterCapabilities != nil && d.TargetCluster != "" {
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.

View File

@@ -11,11 +11,13 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -34,6 +36,4 @@ type Controller interface {
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}

View File

@@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster
return nil
}
// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports.
func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
}
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
// Returns nil when no proxy has reported the capability (defaults to false).
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)

View File

@@ -16,6 +16,8 @@ type store interface {
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -38,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
caps = *capabilities
}
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
@@ -48,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Capabilities: caps,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
@@ -118,6 +126,18 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error)
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr)
}
// ClusterRequireSubdomain returns whether any active proxy in the cluster
// requires a subdomain. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterRequireSubdomain(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {

View File

@@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
// ClusterSupportsCustomPorts mocks base method.
func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
}
// Disconnect mocks base method.
@@ -145,34 +173,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// ClusterSupportsCustomPorts mocks base method.
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()

View File

@@ -2,6 +2,17 @@ package proxy
import "time"
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
type Capabilities struct {
// SupportsCustomPorts indicates whether this proxy can bind arbitrary
// ports for TCP/UDP services. TLS uses SNI routing and is not gated.
SupportsCustomPorts *bool
// RequireSubdomain indicates whether a subdomain label is required in
// front of the cluster domain.
RequireSubdomain *bool
}
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
@@ -11,6 +22,7 @@ type Proxy struct {
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time
UpdatedAt time.Time
}

View File

@@ -75,16 +75,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
require.NoError(t, err)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
mockCaps := proxy.NewMockManager(ctrl)
mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes()
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -93,6 +95,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -75,22 +75,30 @@ type ClusterDeriver interface {
GetClusterDomains() []string
}
// CapabilityProvider queries proxy cluster capabilities from the database.
type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
@@ -237,7 +245,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
}
service.ProxyCluster = proxyCluster
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil {
return err
}
}
@@ -268,11 +276,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
// validateSubdomainRequirement checks whether the domain can be used bare
// (without a subdomain label) on the given cluster. If the cluster reports
// require_subdomain=true and the domain equals the cluster domain, it rejects.
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error {
if domain != cluster {
return nil
}
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster)
if requireSub != nil && *requireSub {
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
}
@@ -280,6 +288,8 @@ func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -287,7 +297,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -307,12 +317,23 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
})
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
// clusterCustomPorts queries whether the cluster supports custom ports.
// Must be called before entering a transaction: the underlying query uses
// the main DB handle, which deadlocks when called inside a transaction
// that already holds the connection.
func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster)
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
@@ -396,12 +417,14 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -504,21 +527,58 @@ type serviceUpdateInfo struct {
}
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service)
if err != nil {
return nil, err
}
svcForCaps := *service
svcForCaps.ProxyCluster = effectiveCluster
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
})
return &updateInfo, err
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
// resolveEffectiveCluster determines the cluster that will be used after the update.
// It reads the existing service without locking and derives the new cluster if the domain changed.
func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) {
existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID)
if err != nil {
return "", err
}
if existing.Domain == svc.Domain {
return existing.ProxyCluster, nil
}
if m.clusterDeriver != nil {
derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
return derived, nil
}
}
return existing.ProxyCluster, nil
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if existingService.Terminated {
return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated")
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
@@ -534,7 +594,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
return err
}
@@ -546,7 +606,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
@@ -1059,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}

View File

@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -1324,11 +1324,11 @@ func TestValidateSubdomainRequirement(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mockCaps := proxy.NewMockManager(ctrl)
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mgr := &Manager{proxyController: mockCtrl}
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
mgr := &Manager{capabilities: mockCaps}
err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "requires a subdomain label")

View File

@@ -184,6 +184,7 @@ type Service struct {
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
Terminated bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
@@ -256,7 +257,7 @@ func (s *Service) ToAPIResponse() *api.Service {
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
Enabled: target.Enabled && !s.Terminated,
}
opts := targetOptionsToAPI(target.Options)
if opts == nil {
@@ -286,7 +287,8 @@ func (s *Service) ToAPIResponse() *api.Service {
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
Enabled: s.Enabled && !s.Terminated,
Terminated: &s.Terminated,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
@@ -785,6 +787,11 @@ func (s *Service) validateHTTPTargets() error {
}
func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto.
target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
@@ -1125,6 +1132,7 @@ func (s *Service) Copy() *Service {
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
Terminated: s.Terminated,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: authCopy,

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
issuer := s.Config.HttpConfig.AuthIssuer
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
var keyFetcher nbjwt.KeyFetcher
// Use embedded IdP configuration if available
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
// Use localhost keys location for internal validation (management has embedded Dex)
keysLocation = oauthProvider.GetLocalKeysLocation()
keyFetcher = oauthProvider.GetKeyFetcher()
// Fall back to default keys location if direct key fetching is not available
if keyFetcher == nil {
keysLocation = oauthProvider.GetLocalKeysLocation()
}
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
keysLocation,
userIDClaim,
audiences,
signingKeyRefreshEnabled)
signingKeyRefreshEnabled,
keyFetcher)
})
}

View File

@@ -117,9 +117,11 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
if embeddedEnabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP service: %v", err)
@@ -195,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
})
}
@@ -212,9 +214,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
s.AfterInit(func(s *BaseServer) {
m.SetClusterCapabilities(s.ServiceProxyController())
})
return &m
})
}

View File

@@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
}
}
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.Delete(proxyID)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
log.WithFields(log.Fields{
@@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
if !proxyAcceptsMapping(conn, m) {
continue
}
mappings = append(mappings, m)
}
return mappings, nil
@@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
log.Debugf("Sending service update to cluster %s", clusterAddr)
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
connVal, ok := s.connectedProxies.Load(proxyID)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
}
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4
// mappings with a custom listen port, since they don't understand the
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
// Old proxies that never reported capabilities don't understand
// custom port mappings.
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
@@ -508,64 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
}
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
// cluster reports custom port support. Returns nil if no proxy has reported
// capabilities (old proxies that predate the field).
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
continue
}
if *conn.capabilities.SupportsCustomPorts {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
// ClusterRequireSubdomain returns whether any connected proxy in the given
// cluster reports that a subdomain is required. Returns nil if no proxy has
// reported the capability (defaults to not required).
func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil {
continue
}
if *conn.capabilities.RequireSubdomain {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {

View File

@@ -53,14 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
return nil
}
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
return ptr(true)
}
func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool {
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
@@ -355,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
const cluster = "proxy.example.com"
// Proxy A supports custom ports.
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
// Proxy B does NOT support custom ports (shared cloud proxy).
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// Modern proxy reports capabilities.
chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
// Legacy proxy never reported capabilities (nil).
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
ctx := context.Background()
// TLS passthrough works on all proxies regardless of custom port support.
// TLS passthrough with custom port: all proxies receive it (SNI routing).
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tls",
@@ -375,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
msgA := drainMapping(chA)
msgB := drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping")
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)")
// Send an HTTP mapping: both should receive it.
// TCP mapping with custom port: only modern proxy receives it.
tcpMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tcp",
AccountId: "account-1",
Domain: "db.example.com",
Mode: "tcp",
ListenPort: 5432,
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
}
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping")
assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping")
// HTTP mapping (no listen port): both receive it.
httpMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-http",
@@ -391,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
msgA = drainMapping(chA)
msgB = drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping")
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping")
// Proxy that reports SupportsCustomPorts=false still receives custom-port
// mappings because it understands the protocol (it's new enough).
chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping")
}
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
@@ -408,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
const cluster = "proxy.example.com"
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// Legacy proxy (no capabilities) still receives TLS since it uses SNI.
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
@@ -421,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
msg := drainMapping(chShared)
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
msg := drainMapping(chLegacy)
assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)")
}
// TestServiceModifyNotifications exercises every possible modification
@@ -589,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) {
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
chLegacy := registerFakeProxy(s, "legacy", cluster)
// TLS passthrough works on all proxies regardless of custom port support
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
@@ -608,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) {
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
chLegacy := registerFakeProxy(s, "legacy", cluster)
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.ListenPort = 0 // default port

View File

@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
}
// GetIdentityProvider mocks base method.

View File

@@ -3138,7 +3138,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
return manager, updateManager, nil
}

View File

@@ -0,0 +1,61 @@
package store
// This file contains migration-only methods on Store.
// They satisfy the migration.MigrationEventStore interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp/migration"
)
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := store.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&activity.Event{}).
Where("initiator_id = ?", oldUserID).
Update("initiator_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.initiator_id: %w", err)
}
if err := tx.Model(&activity.Event{}).
Where("target_id = ?", oldUserID).
Update("target_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.target_id: %w", err)
}
// Raw exec: GORM can't update a PK via Model().Update()
if err := tx.Exec(
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
).Error; err != nil {
return fmt.Errorf("update deleted_users.id: %w", err)
}
return nil
})
}

View File

@@ -0,0 +1,161 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUpdateUserID(t *testing.T) {
ctx := context.Background()
newStore := func(t *testing.T) *Store {
t.Helper()
key, _ := crypt.GenerateKey()
s, err := NewSqlStore(ctx, t.TempDir(), key)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { s.Close(ctx) }) //nolint
return s
}
t.Run("updates initiator_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "old-user",
TargetID: "some-peer",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].InitiatorID)
})
t.Run("updates target_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "some-admin",
TargetID: "old-user",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].TargetID)
})
t.Run("updates deleted_users id", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
// Save an event with email/name meta to create a deleted_users row for "old-user"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "old-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
// Save another event referencing new-user with email/name meta.
// This should upsert (not conflict) because the PK was already migrated.
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "new-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
// The deleted user info should be retrievable via Get (joined on target_id)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
assert.Equal(t, "new-user", ev.TargetID)
}
})
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
store := newStore(t)
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
assert.NoError(t, err)
})
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-a",
TargetID: "peer-1",
AccountID: accountID,
})
assert.NoError(t, err)
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-b",
TargetID: "peer-2",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
if ev.TargetID == "peer-1" {
assert.Equal(t, "user-a-new", ev.InitiatorID)
} else {
assert.Equal(t, "user-b", ev.InitiatorID)
}
}
})
}

View File

@@ -33,15 +33,20 @@ type manager struct {
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
var jwtValidator *nbjwt.Validator
if keyFetcher != nil {
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
} else {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator = nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
}
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),

View File

@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)

View File

@@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()
if gl.db == nil {
return nil, fmt.Errorf("geolocation database is not available")
}
var record Record
err := gl.db.Lookup(ip, &record)
if err != nil {
@@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
gl.mux.Lock()
db := gl.db
gl.db = nil
gl.mux.Unlock()
if db != nil {
if err := db.Close(); err != nil {
return err
}
}

View File

@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}

View File

@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}

View File

@@ -0,0 +1,238 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Accounts_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all accounts", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Account{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
account := got[0]
assert.Equal(t, "test.com", account.Domain)
assert.Equal(t, "private", account.DomainCategory)
assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, 86400, account.Settings.PeerLoginExpiration)
assert.Equal(t, false, account.Settings.RegularUsersViewBlocked)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Accounts_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
trueVal := true
falseVal := false
tt := []struct {
name string
expectedStatus int
requestBody *api.AccountRequest
verifyResponse func(t *testing.T, account *api.Account)
verifyDB func(t *testing.T, account *types.Account)
}{
{
name: "Disable peer login expiration",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: 86400,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled)
},
},
{
name: "Update peer login expiration to 48h",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 172800,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, 172800, account.Settings.PeerLoginExpiration)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration)
},
},
{
name: "Enable regular users view blocked",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
RegularUsersViewBlocked: true,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, true, account.Settings.RegularUsersViewBlocked)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked)
},
},
{
name: "Enable groups propagation",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
GroupsPropagationEnabled: &trueVal,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.NotNil(t, account.Settings.GroupsPropagationEnabled)
assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled)
},
},
{
name: "Enable JWT groups",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
GroupsPropagationEnabled: &falseVal,
JwtGroupsEnabled: &trueVal,
JwtGroupsClaimName: stringPointer("groups"),
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.NotNil(t, account.Settings.JwtGroupsEnabled)
assert.Equal(t, true, *account.Settings.JwtGroupsEnabled)
assert.NotNil(t, account.Settings.JwtGroupsClaimName)
assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled)
assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName)
},
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
got := &api.Account{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, testing_tools.TestAccountId, got.Id)
assert.Equal(t, "test.com", got.Domain)
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbAccount := testing_tools.VerifyAccountSettings(t, db)
tc.verifyDB(t, dbAccount)
})
}
}
}
func stringPointer(s string) *string {
return &s
}

View File

@@ -0,0 +1,554 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Nameservers_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all nameservers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/nameservers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.NameserverGroup{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testNSGroup", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Nameservers_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
expectedStatus int
expectGroup bool
}{
{
name: "Get existing nameserver group",
nsGroupId: "testNSGroupId",
expectedStatus: http.StatusOK,
expectGroup: true,
},
{
name: "Get non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
expectedStatus: http.StatusNotFound,
expectGroup: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectGroup {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testNSGroupId", got.Id)
assert.Equal(t, "testNSGroup", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Nameservers_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PostApiDnsNameserversJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
}{
{
name: "Create nameserver group with single NS",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "newNSGroup",
Description: "a new nameserver group",
Nameservers: []api.Nameserver{
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: false,
Domains: []string{"test.com"},
Enabled: true,
SearchDomainsEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.NotEmpty(t, nsGroup.Id)
assert.Equal(t, "newNSGroup", nsGroup.Name)
assert.Equal(t, 1, len(nsGroup.Nameservers))
assert.Equal(t, false, nsGroup.Primary)
},
},
{
name: "Create primary nameserver group",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "primaryNS",
Description: "primary nameserver",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.Equal(t, true, nsGroup.Primary)
},
},
{
name: "Create nameserver group with empty groups",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "emptyGroupsNS",
Description: "no groups",
Nameservers: []api.Nameserver{
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
},
Groups: []string{},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/nameservers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created NS group directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbNS := testing_tools.VerifyNSGroupInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbNS.Name)
assert.Equal(t, got.Primary, dbNS.Primary)
assert.Equal(t, len(got.Nameservers), len(dbNS.NameServers))
assert.Equal(t, got.Enabled, dbNS.Enabled)
assert.Equal(t, got.SearchDomainsEnabled, dbNS.SearchDomainsEnabled)
}
})
}
}
}
func Test_Nameservers_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
requestBody *api.PutApiDnsNameserversNsgroupIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
}{
{
name: "Update nameserver group name",
nsGroupId: "testNSGroupId",
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "updatedNSGroup",
Description: "updated description",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: false,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.Equal(t, "updatedNSGroup", nsGroup.Name)
assert.Equal(t, "updated description", nsGroup.Description)
},
},
{
name: "Update non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "whatever",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated NS group directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbNS := testing_tools.VerifyNSGroupInDB(t, db, tc.nsGroupId)
assert.Equal(t, "updatedNSGroup", dbNS.Name)
assert.Equal(t, "updated description", dbNS.Description)
assert.Equal(t, false, dbNS.Primary)
assert.Equal(t, true, dbNS.Enabled)
assert.Equal(t, 1, len(dbNS.NameServers))
assert.Equal(t, false, dbNS.SearchDomainsEnabled)
}
})
}
}
}
func Test_Nameservers_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
expectedStatus int
}{
{
name: "Delete existing nameserver group",
nsGroupId: "testNSGroupId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyNSGroupNotInDB(t, db, tc.nsGroupId)
}
})
}
}
}
func Test_DnsSettings_Get(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get DNS settings", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/settings", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.DNSSettings{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotNil(t, got.DisabledManagementGroups)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_DnsSettings_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PutApiDnsSettingsJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, settings *api.DNSSettings)
expectedDBDisabledMgmtLen int
expectedDBDisabledMgmtItem string
}{
{
name: "Update disabled management groups",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
t.Helper()
assert.Equal(t, 1, len(settings.DisabledManagementGroups))
assert.Equal(t, testing_tools.TestGroupId, settings.DisabledManagementGroups[0])
},
expectedDBDisabledMgmtLen: 1,
expectedDBDisabledMgmtItem: testing_tools.TestGroupId,
},
{
name: "Update with empty disabled management groups",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
t.Helper()
assert.Equal(t, 0, len(settings.DisabledManagementGroups))
},
expectedDBDisabledMgmtLen: 0,
},
{
name: "Update with non-existing group",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"nonExistingGroupId"},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/dns/settings", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSSettings{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify DNS settings directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbAccount := testing_tools.VerifyAccountSettings(t, db)
assert.Equal(t, tc.expectedDBDisabledMgmtLen, len(dbAccount.DNSSettings.DisabledManagementGroups))
if tc.expectedDBDisabledMgmtItem != "" {
assert.Contains(t, dbAccount.DNSSettings.DisabledManagementGroups, tc.expectedDBDisabledMgmtItem)
}
}
})
}
}
}

View File

@@ -0,0 +1,105 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Events_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all events", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
// First, perform a mutation to generate an event (create a group as admin)
groupBody, err := json.Marshal(&api.GroupRequest{Name: "eventTestGroup"})
if err != nil {
t.Fatalf("Failed to marshal group request: %v", err)
}
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
createRecorder := httptest.NewRecorder()
apiHandler.ServeHTTP(createRecorder, createReq)
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
// Now query events
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Event{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
// Verify the group creation event exists
found := false
for _, event := range got {
if event.ActivityCode == "group.add" {
found = true
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
assert.Equal(t, "Group created", event.Activity)
break
}
}
assert.True(t, found, "Expected to find a group.add event")
})
}
}
func Test_Events_GetAll_Empty(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, true)
if !expectResponse {
return
}
got := []api.Event{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty events list when no mutations have been performed")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
}

View File

@@ -0,0 +1,382 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Groups_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all groups", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/groups", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Group{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 2)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Groups_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
groupId string
expectedStatus int
expectGroup bool
}{
{
name: "Get existing group",
groupId: testing_tools.TestGroupId,
expectedStatus: http.StatusOK,
expectGroup: true,
},
{
name: "Get non-existing group",
groupId: "nonExistingGroupId",
expectedStatus: http.StatusNotFound,
expectGroup: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectGroup {
got := &api.Group{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.groupId, got.Id)
assert.Equal(t, "testGroupName", got.Name)
assert.Equal(t, 1, got.PeersCount)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Groups_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.GroupRequest
expectedStatus int
verifyResponse func(t *testing.T, group *api.Group)
}{
{
name: "Create group with valid name",
requestBody: &api.GroupRequest{
Name: "brandNewGroup",
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, group *api.Group) {
t.Helper()
assert.NotEmpty(t, group.Id)
assert.Equal(t, "brandNewGroup", group.Name)
assert.Equal(t, 0, group.PeersCount)
},
},
{
name: "Create group with peers",
requestBody: &api.GroupRequest{
Name: "groupWithPeers",
Peers: &[]string{testing_tools.TestPeerId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, group *api.Group) {
t.Helper()
assert.NotEmpty(t, group.Id)
assert.Equal(t, "groupWithPeers", group.Name)
assert.Equal(t, 1, group.PeersCount)
},
},
{
name: "Create group with empty name",
requestBody: &api.GroupRequest{
Name: "",
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/groups", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Group{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify group exists in DB
db := testing_tools.GetDB(t, am.GetStore())
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)
}
})
}
}
}
func Test_Groups_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
groupId string
requestBody *api.GroupRequest
expectedStatus int
verifyResponse func(t *testing.T, group *api.Group)
}{
{
name: "Update group name",
groupId: testing_tools.TestGroupId,
requestBody: &api.GroupRequest{
Name: "updatedGroupName",
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, group *api.Group) {
t.Helper()
assert.Equal(t, testing_tools.TestGroupId, group.Id)
assert.Equal(t, "updatedGroupName", group.Name)
},
},
{
name: "Update group peers",
groupId: testing_tools.TestGroupId,
requestBody: &api.GroupRequest{
Name: "testGroupName",
Peers: &[]string{},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, group *api.Group) {
t.Helper()
assert.Equal(t, 0, group.PeersCount)
},
},
{
name: "Update with empty name",
groupId: testing_tools.TestGroupId,
requestBody: &api.GroupRequest{
Name: "",
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Update non-existing group",
groupId: "nonExistingGroupId",
requestBody: &api.GroupRequest{
Name: "someName",
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Group{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify updated group in DB
db := testing_tools.GetDB(t, am.GetStore())
dbGroup := testing_tools.VerifyGroupInDB(t, db, tc.groupId)
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)
}
})
}
}
}
func Test_Groups_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
groupId string
expectedStatus int
}{
{
name: "Delete existing group not in use",
groupId: testing_tools.NewGroupId,
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing group",
groupId: "nonExistingGroupId",
expectedStatus: http.StatusBadRequest,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/groups.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/groups/{groupId}", "{groupId}", tc.groupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyGroupNotInDB(t, db, tc.groupId)
}
})
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,605 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
const (
testPeerId2 = "testPeerId2"
)
func Test_Peers_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{
name: "Regular user",
userId: testing_tools.TestUserId,
expectResponse: false,
},
{
name: "Admin user",
userId: testing_tools.TestAdminId,
expectResponse: true,
},
{
name: "Owner user",
userId: testing_tools.TestOwnerId,
expectResponse: true,
},
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: true,
},
{
name: "Admin service user",
userId: testing_tools.TestServiceAdminId,
expectResponse: true,
},
{
name: "Blocked user",
userId: testing_tools.BlockedUserId,
expectResponse: false,
},
{
name: "Other user",
userId: testing_tools.OtherUserId,
expectResponse: false,
},
{
name: "Invalid token",
userId: testing_tools.InvalidToken,
expectResponse: false,
},
}
for _, user := range users {
t.Run(user.name+" - Get all peers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/peers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
var got []api.PeerBatch
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 2, "Expected at least 2 peers")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Peers_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{
name: "Regular user",
userId: testing_tools.TestUserId,
expectResponse: false,
},
{
name: "Admin user",
userId: testing_tools.TestAdminId,
expectResponse: true,
},
{
name: "Owner user",
userId: testing_tools.TestOwnerId,
expectResponse: true,
},
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: true,
},
{
name: "Admin service user",
userId: testing_tools.TestServiceAdminId,
expectResponse: true,
},
{
name: "Blocked user",
userId: testing_tools.BlockedUserId,
expectResponse: false,
},
{
name: "Other user",
userId: testing_tools.OtherUserId,
expectResponse: false,
},
{
name: "Invalid token",
userId: testing_tools.InvalidToken,
expectResponse: false,
},
}
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestId string
verifyResponse func(t *testing.T, peer *api.Peer)
}{
{
name: "Get existing peer",
requestType: http.MethodGet,
requestPath: "/api/peers/{peerId}",
requestId: testing_tools.TestPeerId,
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, peer *api.Peer) {
t.Helper()
assert.Equal(t, testing_tools.TestPeerId, peer.Id)
assert.Equal(t, "test-peer-1", peer.Name)
assert.Equal(t, "test-host-1", peer.Hostname)
assert.Equal(t, "Debian GNU/Linux ", peer.Os)
assert.Equal(t, "0.12.0", peer.Version)
assert.Equal(t, false, peer.SshEnabled)
assert.Equal(t, true, peer.LoginExpirationEnabled)
},
},
{
name: "Get second existing peer",
requestType: http.MethodGet,
requestPath: "/api/peers/{peerId}",
requestId: testPeerId2,
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, peer *api.Peer) {
t.Helper()
assert.Equal(t, testPeerId2, peer.Id)
assert.Equal(t, "test-peer-2", peer.Name)
assert.Equal(t, "test-host-2", peer.Hostname)
assert.Equal(t, "Ubuntu ", peer.Os)
assert.Equal(t, true, peer.SshEnabled)
assert.Equal(t, false, peer.LoginExpirationEnabled)
assert.Equal(t, true, peer.Connected)
},
},
{
name: "Get non-existing peer",
requestType: http.MethodGet,
requestPath: "/api/peers/{peerId}",
requestId: "nonExistingPeerId",
expectedStatus: http.StatusNotFound,
verifyResponse: nil,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Peer{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Peers_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{
name: "Regular user",
userId: testing_tools.TestUserId,
expectResponse: false,
},
{
name: "Admin user",
userId: testing_tools.TestAdminId,
expectResponse: true,
},
{
name: "Owner user",
userId: testing_tools.TestOwnerId,
expectResponse: true,
},
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: false,
},
{
name: "Admin service user",
userId: testing_tools.TestServiceAdminId,
expectResponse: true,
},
{
name: "Blocked user",
userId: testing_tools.BlockedUserId,
expectResponse: false,
},
{
name: "Other user",
userId: testing_tools.OtherUserId,
expectResponse: false,
},
{
name: "Invalid token",
userId: testing_tools.InvalidToken,
expectResponse: false,
},
}
tt := []struct {
name string
expectedStatus int
requestBody *api.PeerRequest
requestType string
requestPath string
requestId string
verifyResponse func(t *testing.T, peer *api.Peer)
}{
{
name: "Update peer name",
requestType: http.MethodPut,
requestPath: "/api/peers/{peerId}",
requestId: testing_tools.TestPeerId,
requestBody: &api.PeerRequest{
Name: "updated-peer-name",
SshEnabled: false,
LoginExpirationEnabled: true,
InactivityExpirationEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, peer *api.Peer) {
t.Helper()
assert.Equal(t, testing_tools.TestPeerId, peer.Id)
assert.Equal(t, "updated-peer-name", peer.Name)
assert.Equal(t, false, peer.SshEnabled)
assert.Equal(t, true, peer.LoginExpirationEnabled)
},
},
{
name: "Enable SSH on peer",
requestType: http.MethodPut,
requestPath: "/api/peers/{peerId}",
requestId: testing_tools.TestPeerId,
requestBody: &api.PeerRequest{
Name: "test-peer-1",
SshEnabled: true,
LoginExpirationEnabled: true,
InactivityExpirationEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, peer *api.Peer) {
t.Helper()
assert.Equal(t, testing_tools.TestPeerId, peer.Id)
assert.Equal(t, "test-peer-1", peer.Name)
assert.Equal(t, true, peer.SshEnabled)
assert.Equal(t, true, peer.LoginExpirationEnabled)
},
},
{
name: "Disable login expiration on peer",
requestType: http.MethodPut,
requestPath: "/api/peers/{peerId}",
requestId: testing_tools.TestPeerId,
requestBody: &api.PeerRequest{
Name: "test-peer-1",
SshEnabled: false,
LoginExpirationEnabled: false,
InactivityExpirationEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, peer *api.Peer) {
t.Helper()
assert.Equal(t, testing_tools.TestPeerId, peer.Id)
assert.Equal(t, false, peer.LoginExpirationEnabled)
},
},
{
name: "Update non-existing peer",
requestType: http.MethodPut,
requestPath: "/api/peers/{peerId}",
requestId: "nonExistingPeerId",
requestBody: &api.PeerRequest{
Name: "updated-name",
SshEnabled: false,
LoginExpirationEnabled: false,
InactivityExpirationEnabled: false,
},
expectedStatus: http.StatusNotFound,
verifyResponse: nil,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Peer{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify updated peer in DB
db := testing_tools.GetDB(t, am.GetStore())
dbPeer := testing_tools.VerifyPeerInDB(t, db, tc.requestId)
assert.Equal(t, tc.requestBody.Name, dbPeer.Name)
assert.Equal(t, tc.requestBody.SshEnabled, dbPeer.SSHEnabled)
assert.Equal(t, tc.requestBody.LoginExpirationEnabled, dbPeer.LoginExpirationEnabled)
assert.Equal(t, tc.requestBody.InactivityExpirationEnabled, dbPeer.InactivityExpirationEnabled)
}
})
}
}
}
func Test_Peers_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{
name: "Regular user",
userId: testing_tools.TestUserId,
expectResponse: false,
},
{
name: "Admin user",
userId: testing_tools.TestAdminId,
expectResponse: true,
},
{
name: "Owner user",
userId: testing_tools.TestOwnerId,
expectResponse: true,
},
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: false,
},
{
name: "Admin service user",
userId: testing_tools.TestServiceAdminId,
expectResponse: true,
},
{
name: "Blocked user",
userId: testing_tools.BlockedUserId,
expectResponse: false,
},
{
name: "Other user",
userId: testing_tools.OtherUserId,
expectResponse: false,
},
{
name: "Invalid token",
userId: testing_tools.InvalidToken,
expectResponse: false,
},
}
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestId string
}{
{
name: "Delete existing peer",
requestType: http.MethodDelete,
requestPath: "/api/peers/{peerId}",
requestId: testPeerId2,
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing peer",
requestType: http.MethodDelete,
requestPath: "/api/peers/{peerId}",
requestId: "nonExistingPeerId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
// Verify peer is actually deleted in DB
if tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPeerNotInDB(t, db, tc.requestId)
}
})
}
}
}
func Test_Peers_GetAccessiblePeers(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{
name: "Regular user",
userId: testing_tools.TestUserId,
expectResponse: false,
},
{
name: "Admin user",
userId: testing_tools.TestAdminId,
expectResponse: true,
},
{
name: "Owner user",
userId: testing_tools.TestOwnerId,
expectResponse: true,
},
{
name: "Regular service user",
userId: testing_tools.TestServiceUserId,
expectResponse: false,
},
{
name: "Admin service user",
userId: testing_tools.TestServiceAdminId,
expectResponse: true,
},
{
name: "Blocked user",
userId: testing_tools.BlockedUserId,
expectResponse: false,
},
{
name: "Other user",
userId: testing_tools.OtherUserId,
expectResponse: false,
},
{
name: "Invalid token",
userId: testing_tools.InvalidToken,
expectResponse: false,
},
}
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestId string
}{
{
name: "Get accessible peers for existing peer",
requestType: http.MethodGet,
requestPath: "/api/peers/{peerId}/accessible-peers",
requestId: testing_tools.TestPeerId,
expectedStatus: http.StatusOK,
},
{
name: "Get accessible peers for non-existing peer",
requestType: http.MethodGet,
requestPath: "/api/peers/{peerId}/accessible-peers",
requestId: "nonExistingPeerId",
expectedStatus: http.StatusOK,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/peers_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{peerId}", tc.requestId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectedStatus == http.StatusOK {
var got []api.AccessiblePeer
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// The accessible peers list should be a valid array (may be empty if no policies connect peers)
assert.NotNil(t, got, "Expected accessible peers to be a valid array")
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}

View File

@@ -0,0 +1,488 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Policies_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all policies", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/policies", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Policy{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testPolicy", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Policies_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
policyId string
expectedStatus int
expectPolicy bool
}{
{
name: "Get existing policy",
policyId: "testPolicyId",
expectedStatus: http.StatusOK,
expectPolicy: true,
},
{
name: "Get non-existing policy",
policyId: "nonExistingPolicyId",
expectedStatus: http.StatusNotFound,
expectPolicy: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectPolicy {
got := &api.Policy{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotNil(t, got.Id)
assert.Equal(t, tc.policyId, *got.Id)
assert.Equal(t, "testPolicy", got.Name)
assert.Equal(t, true, got.Enabled)
assert.GreaterOrEqual(t, len(got.Rules), 1)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Policies_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
srcGroups := []string{testing_tools.TestGroupId}
dstGroups := []string{testing_tools.TestGroupId}
tt := []struct {
name string
requestBody *api.PolicyCreate
expectedStatus int
verifyResponse func(t *testing.T, policy *api.Policy)
}{
{
name: "Create policy with accept rule",
requestBody: &api.PolicyCreate{
Name: "newPolicy",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "allowAll",
Enabled: true,
Action: "accept",
Protocol: "all",
Bidirectional: true,
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, policy *api.Policy) {
t.Helper()
assert.NotNil(t, policy.Id)
assert.Equal(t, "newPolicy", policy.Name)
assert.Equal(t, true, policy.Enabled)
assert.Equal(t, 1, len(policy.Rules))
assert.Equal(t, "allowAll", policy.Rules[0].Name)
},
},
{
name: "Create policy with drop rule",
requestBody: &api.PolicyCreate{
Name: "dropPolicy",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "dropAll",
Enabled: true,
Action: "drop",
Protocol: "all",
Bidirectional: true,
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, policy *api.Policy) {
t.Helper()
assert.Equal(t, "dropPolicy", policy.Name)
},
},
{
name: "Create policy with TCP rule and ports",
requestBody: &api.PolicyCreate{
Name: "tcpPolicy",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "tcpRule",
Enabled: true,
Action: "accept",
Protocol: "tcp",
Bidirectional: true,
Sources: &srcGroups,
Destinations: &dstGroups,
Ports: &[]string{"80", "443"},
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, policy *api.Policy) {
t.Helper()
assert.Equal(t, "tcpPolicy", policy.Name)
assert.NotNil(t, policy.Rules[0].Ports)
assert.Equal(t, 2, len(*policy.Rules[0].Ports))
},
},
{
name: "Create policy with empty name",
requestBody: &api.PolicyCreate{
Name: "",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "rule",
Enabled: true,
Action: "accept",
Protocol: "all",
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create policy with no rules",
requestBody: &api.PolicyCreate{
Name: "noRulesPolicy",
Enabled: true,
Rules: []api.PolicyRuleUpdate{},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/policies", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Policy{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify policy exists in DB with correct fields
db := testing_tools.GetDB(t, am.GetStore())
dbPolicy := testing_tools.VerifyPolicyInDB(t, db, *got.Id)
assert.Equal(t, tc.requestBody.Name, dbPolicy.Name)
assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled)
assert.Equal(t, len(tc.requestBody.Rules), len(dbPolicy.Rules))
}
})
}
}
}
func Test_Policies_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
srcGroups := []string{testing_tools.TestGroupId}
dstGroups := []string{testing_tools.TestGroupId}
tt := []struct {
name string
policyId string
requestBody *api.PolicyCreate
expectedStatus int
verifyResponse func(t *testing.T, policy *api.Policy)
}{
{
name: "Update policy name",
policyId: "testPolicyId",
requestBody: &api.PolicyCreate{
Name: "updatedPolicy",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "testRule",
Enabled: true,
Action: "accept",
Protocol: "all",
Bidirectional: true,
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, policy *api.Policy) {
t.Helper()
assert.Equal(t, "updatedPolicy", policy.Name)
},
},
{
name: "Update policy enabled state",
policyId: "testPolicyId",
requestBody: &api.PolicyCreate{
Name: "testPolicy",
Enabled: false,
Rules: []api.PolicyRuleUpdate{
{
Name: "testRule",
Enabled: true,
Action: "accept",
Protocol: "all",
Bidirectional: true,
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, policy *api.Policy) {
t.Helper()
assert.Equal(t, false, policy.Enabled)
},
},
{
name: "Update non-existing policy",
policyId: "nonExistingPolicyId",
requestBody: &api.PolicyCreate{
Name: "whatever",
Enabled: true,
Rules: []api.PolicyRuleUpdate{
{
Name: "rule",
Enabled: true,
Action: "accept",
Protocol: "all",
Sources: &srcGroups,
Destinations: &dstGroups,
},
},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Policy{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify updated policy in DB
db := testing_tools.GetDB(t, am.GetStore())
dbPolicy := testing_tools.VerifyPolicyInDB(t, db, tc.policyId)
assert.Equal(t, tc.requestBody.Name, dbPolicy.Name)
assert.Equal(t, tc.requestBody.Enabled, dbPolicy.Enabled)
}
})
}
}
}
func Test_Policies_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
policyId string
expectedStatus int
}{
{
name: "Delete existing policy",
policyId: "testPolicyId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing policy",
policyId: "nonExistingPolicyId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/policies.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/policies/{policyId}", "{policyId}", tc.policyId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPolicyNotInDB(t, db, tc.policyId)
}
})
}
}
}

View File

@@ -0,0 +1,455 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Routes_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all routes", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/routes", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Route{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 2, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Routes_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
routeId string
expectedStatus int
expectRoute bool
}{
{
name: "Get existing route",
routeId: "testRouteId",
expectedStatus: http.StatusOK,
expectRoute: true,
},
{
name: "Get non-existing route",
routeId: "nonExistingRouteId",
expectedStatus: http.StatusNotFound,
expectRoute: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectRoute {
got := &api.Route{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.routeId, got.Id)
assert.Equal(t, "Test Network Route", got.Description)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Routes_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
networkCIDR := "10.10.0.0/24"
peerID := testing_tools.TestPeerId
peerGroups := []string{"peerGroupId"}
tt := []struct {
name string
requestBody *api.RouteRequest
expectedStatus int
verifyResponse func(t *testing.T, route *api.Route)
}{
{
name: "Create network route with peer",
requestBody: &api.RouteRequest{
Description: "New network route",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "newNet",
Metric: 100,
Masquerade: true,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, route *api.Route) {
t.Helper()
assert.NotEmpty(t, route.Id)
assert.Equal(t, "New network route", route.Description)
assert.Equal(t, 100, route.Metric)
assert.Equal(t, true, route.Masquerade)
assert.Equal(t, true, route.Enabled)
},
},
{
name: "Create network route with peer groups",
requestBody: &api.RouteRequest{
Description: "Route with peer groups",
Network: &networkCIDR,
PeerGroups: &peerGroups,
NetworkId: "peerGroupNet",
Metric: 150,
Masquerade: false,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, route *api.Route) {
t.Helper()
assert.NotEmpty(t, route.Id)
assert.Equal(t, "Route with peer groups", route.Description)
},
},
{
name: "Create route with empty network_id",
requestBody: &api.RouteRequest{
Description: "Empty net id",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "",
Metric: 100,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create route with metric 0",
requestBody: &api.RouteRequest{
Description: "Zero metric",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "zeroMetric",
Metric: 0,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create route with metric 10000",
requestBody: &api.RouteRequest{
Description: "High metric",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "highMetric",
Metric: 10000,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/routes", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Route{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify route exists in DB with correct fields
db := testing_tools.GetDB(t, am.GetStore())
dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id))
assert.Equal(t, tc.requestBody.Description, dbRoute.Description)
assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric)
assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade)
assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled)
assert.Equal(t, route.NetID(tc.requestBody.NetworkId), dbRoute.NetID)
}
})
}
}
}
func Test_Routes_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
networkCIDR := "10.0.0.0/24"
peerID := testing_tools.TestPeerId
tt := []struct {
name string
routeId string
requestBody *api.RouteRequest
expectedStatus int
verifyResponse func(t *testing.T, route *api.Route)
}{
{
name: "Update route description",
routeId: "testRouteId",
requestBody: &api.RouteRequest{
Description: "Updated description",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "testNet",
Metric: 100,
Masquerade: true,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, route *api.Route) {
t.Helper()
assert.Equal(t, "testRouteId", route.Id)
assert.Equal(t, "Updated description", route.Description)
},
},
{
name: "Update route metric",
routeId: "testRouteId",
requestBody: &api.RouteRequest{
Description: "Test Network Route",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "testNet",
Metric: 500,
Masquerade: true,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, route *api.Route) {
t.Helper()
assert.Equal(t, 500, route.Metric)
},
},
{
name: "Update non-existing route",
routeId: "nonExistingRouteId",
requestBody: &api.RouteRequest{
Description: "whatever",
Network: &networkCIDR,
Peer: &peerID,
NetworkId: "testNet",
Metric: 100,
Enabled: true,
Groups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Route{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify updated route in DB
db := testing_tools.GetDB(t, am.GetStore())
dbRoute := testing_tools.VerifyRouteInDB(t, db, route.ID(got.Id))
assert.Equal(t, tc.requestBody.Description, dbRoute.Description)
assert.Equal(t, tc.requestBody.Metric, dbRoute.Metric)
assert.Equal(t, tc.requestBody.Masquerade, dbRoute.Masquerade)
assert.Equal(t, tc.requestBody.Enabled, dbRoute.Enabled)
}
})
}
}
}
func Test_Routes_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
routeId string
expectedStatus int
}{
{
name: "Delete existing route",
routeId: "testRouteId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing route",
routeId: "nonExistingRouteId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/routes.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/routes/{routeId}", "{routeId}", tc.routeId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify route was deleted from DB for successful deletes
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyRouteNotInDB(t, db, route.ID(tc.routeId))
}
})
}
}
}

View File

@@ -3,7 +3,6 @@
package integration
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -14,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -254,7 +252,7 @@ func Test_SetupKeys_Create(t *testing.T) {
expectedResponse: nil,
},
{
name: "Create Setup Key",
name: "Create Setup Key with nil AutoGroups",
requestType: http.MethodPost,
requestPath: "/api/setup-keys",
requestBody: &api.CreateSetupKeyRequest{
@@ -308,14 +306,15 @@ func Test_SetupKeys_Create(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
gotID := got.Id
validateCreatedKey(t, tc.expectedResponse, got)
key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id)
if err != nil {
return
}
validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key))
// Verify setup key exists in DB via gorm
db := testing_tools.GetDB(t, am.GetStore())
dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID)
assert.Equal(t, tc.expectedResponse.Name, dbKey.Name)
assert.Equal(t, tc.expectedResponse.Revoked, dbKey.Revoked)
assert.Equal(t, tc.expectedResponse.UsageLimit, dbKey.UsageLimit)
select {
case <-done:
@@ -571,7 +570,7 @@ func Test_SetupKeys_Update(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
@@ -594,14 +593,16 @@ func Test_SetupKeys_Update(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
gotID := got.Id
gotRevoked := got.Revoked
gotUsageLimit := got.UsageLimit
validateCreatedKey(t, tc.expectedResponse, got)
key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id)
if err != nil {
return
}
validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key))
// Verify updated setup key in DB via gorm
db := testing_tools.GetDB(t, am.GetStore())
dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID)
assert.Equal(t, gotRevoked, dbKey.Revoked)
assert.Equal(t, gotUsageLimit, dbKey.UsageLimit)
select {
case <-done:
@@ -759,8 +760,8 @@ func Test_SetupKeys_Get(t *testing.T) {
apiHandler.ServeHTTP(recorder, req)
content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectRespnose {
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
got := &api.SetupKey{}
@@ -768,14 +769,16 @@ func Test_SetupKeys_Get(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
gotID := got.Id
gotName := got.Name
gotRevoked := got.Revoked
validateCreatedKey(t, tc.expectedResponse, got)
key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id)
if err != nil {
return
}
validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key))
// Verify setup key in DB via gorm
db := testing_tools.GetDB(t, am.GetStore())
dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID)
assert.Equal(t, gotName, dbKey.Name)
assert.Equal(t, gotRevoked, dbKey.Revoked)
select {
case <-done:
@@ -928,15 +931,17 @@ func Test_SetupKeys_GetAll(t *testing.T) {
return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit
})
db := testing_tools.GetDB(t, am.GetStore())
for i := range tc.expectedResponse {
gotID := got[i].Id
gotName := got[i].Name
gotRevoked := got[i].Revoked
validateCreatedKey(t, tc.expectedResponse[i], &got[i])
key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id)
if err != nil {
return
}
validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key))
// Verify each setup key in DB via gorm
dbKey := testing_tools.VerifySetupKeyInDB(t, db, gotID)
assert.Equal(t, gotName, dbKey.Name)
assert.Equal(t, gotRevoked, dbKey.Revoked)
}
select {
@@ -1104,8 +1109,9 @@ func Test_SetupKeys_Delete(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
_, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id)
assert.Errorf(t, err, "Expected error when trying to get deleted key")
// Verify setup key deleted from DB via gorm
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifySetupKeyNotInDB(t, db, got.Id)
select {
case <-done:
@@ -1120,7 +1126,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) {
t.Helper()
if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) ||
if (got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second))) ||
got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) ||
got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) {
got.Expires = time.Time{}

View File

@@ -0,0 +1,701 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Users_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all users", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.User{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 1)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Users_GetAll_ServiceUsers(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all service users", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users?service_user=true", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.User{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
for _, u := range got {
assert.NotNil(t, u.IsServiceUser)
assert.Equal(t, true, *u.IsServiceUser)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Users_Create_ServiceUser(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.UserCreateRequest
expectedStatus int
verifyResponse func(t *testing.T, user *api.User)
}{
{
name: "Create service user with admin role",
requestBody: &api.UserCreateRequest{
Role: "admin",
IsServiceUser: true,
AutoGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.NotEmpty(t, user.Id)
assert.Equal(t, "admin", user.Role)
assert.NotNil(t, user.IsServiceUser)
assert.Equal(t, true, *user.IsServiceUser)
},
},
{
name: "Create service user with user role",
requestBody: &api.UserCreateRequest{
Role: "user",
IsServiceUser: true,
AutoGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.NotEmpty(t, user.Id)
assert.Equal(t, "user", user.Role)
},
},
{
name: "Create service user with empty auto_groups",
requestBody: &api.UserCreateRequest{
Role: "admin",
IsServiceUser: true,
AutoGroups: []string{},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.NotEmpty(t, user.Id)
},
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.User{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify user in DB
db := testing_tools.GetDB(t, am.GetStore())
dbUser := testing_tools.VerifyUserInDB(t, db, got.Id)
assert.True(t, dbUser.IsServiceUser)
assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role))
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Users_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
requestBody *api.UserRequest
expectedStatus int
verifyResponse func(t *testing.T, user *api.User)
}{
{
name: "Update user role to admin",
targetUserId: testing_tools.TestUserId,
requestBody: &api.UserRequest{
Role: "admin",
AutoGroups: []string{},
IsBlocked: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.Equal(t, "admin", user.Role)
},
},
{
name: "Update user auto_groups",
targetUserId: testing_tools.TestUserId,
requestBody: &api.UserRequest{
Role: "user",
AutoGroups: []string{testing_tools.TestGroupId},
IsBlocked: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.Equal(t, 1, len(user.AutoGroups))
},
},
{
name: "Block user",
targetUserId: testing_tools.TestUserId,
requestBody: &api.UserRequest{
Role: "user",
AutoGroups: []string{},
IsBlocked: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, user *api.User) {
t.Helper()
assert.Equal(t, true, user.IsBlocked)
},
},
{
name: "Update non-existing user",
targetUserId: "nonExistingUserId",
requestBody: &api.UserRequest{
Role: "user",
AutoGroups: []string{},
IsBlocked: false,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.User{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify updated fields in DB
if tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
dbUser := testing_tools.VerifyUserInDB(t, db, tc.targetUserId)
assert.Equal(t, string(dbUser.Role), string(tc.requestBody.Role))
assert.Equal(t, dbUser.Blocked, tc.requestBody.IsBlocked)
assert.ElementsMatch(t, dbUser.AutoGroups, tc.requestBody.AutoGroups)
}
}
})
}
}
}
func Test_Users_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Delete existing service user",
targetUserId: "deletableServiceUserId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify user deleted from DB for successful deletes
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PATs_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all PATs for service user", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/users/{userId}/tokens", "{userId}", testing_tools.TestServiceUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.PersonalAccessToken{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "serviceToken", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_PATs_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
tokenId string
expectedStatus int
expectToken bool
}{
{
name: "Get existing PAT",
tokenId: "serviceTokenId",
expectedStatus: http.StatusOK,
expectToken: true,
},
{
name: "Get non-existing PAT",
tokenId: "nonExistingTokenId",
expectedStatus: http.StatusNotFound,
expectToken: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1)
path = strings.Replace(path, "{tokenId}", tc.tokenId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectToken {
got := &api.PersonalAccessToken{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "serviceTokenId", got.Id)
assert.Equal(t, "serviceToken", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PATs_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
requestBody *api.PersonalAccessTokenRequest
expectedStatus int
verifyResponse func(t *testing.T, pat *api.PersonalAccessTokenGenerated)
}{
{
name: "Create PAT with 30 day expiry",
targetUserId: testing_tools.TestServiceUserId,
requestBody: &api.PersonalAccessTokenRequest{
Name: "newPAT",
ExpiresIn: 30,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) {
t.Helper()
assert.NotEmpty(t, pat.PlainToken)
assert.Equal(t, "newPAT", pat.PersonalAccessToken.Name)
},
},
{
name: "Create PAT with 365 day expiry",
targetUserId: testing_tools.TestServiceUserId,
requestBody: &api.PersonalAccessTokenRequest{
Name: "longPAT",
ExpiresIn: 365,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, pat *api.PersonalAccessTokenGenerated) {
t.Helper()
assert.NotEmpty(t, pat.PlainToken)
assert.Equal(t, "longPAT", pat.PersonalAccessToken.Name)
},
},
{
name: "Create PAT with empty name",
targetUserId: testing_tools.TestServiceUserId,
requestBody: &api.PersonalAccessTokenRequest{
Name: "",
ExpiresIn: 30,
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create PAT with 0 day expiry",
targetUserId: testing_tools.TestServiceUserId,
requestBody: &api.PersonalAccessTokenRequest{
Name: "zeroPAT",
ExpiresIn: 0,
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create PAT with expiry over 365 days",
targetUserId: testing_tools.TestServiceUserId,
requestBody: &api.PersonalAccessTokenRequest{
Name: "tooLongPAT",
ExpiresIn: 400,
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, strings.Replace("/api/users/{userId}/tokens", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PersonalAccessTokenGenerated{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify PAT in DB
db := testing_tools.GetDB(t, am.GetStore())
dbPAT := testing_tools.VerifyPATInDB(t, db, got.PersonalAccessToken.Id)
assert.Equal(t, tc.requestBody.Name, dbPAT.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PATs_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
tokenId string
expectedStatus int
}{
{
name: "Delete existing PAT",
tokenId: "serviceTokenId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing PAT",
tokenId: "nonExistingTokenId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
path := strings.Replace("/api/users/{userId}/tokens/{tokenId}", "{userId}", testing_tools.TestServiceUserId, 1)
path = strings.Replace(path, "{tokenId}", tc.tokenId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify PAT deleted from DB for successful deletes
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPATNotInDB(t, db, tc.tokenId)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}

View File

@@ -0,0 +1,18 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);

View File

@@ -0,0 +1,21 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0);

View File

@@ -0,0 +1,18 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -0,0 +1,19 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -0,0 +1,25 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`domain` text,`prefix` text,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO networks VALUES('testNetworkId','testAccountId','testNetwork','test network description');
INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','testPeerId','[]',1,100,1);
INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','testResource','test resource description','host','','"3.3.3.3/32"',1);

View File

@@ -0,0 +1,20 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId","testPeerId2"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','test-host-1','linux','Linux','','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-1','test-peer-1','2023-03-02 09:21:02.189035775+01:00',0,0,0,'testUserId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('testPeerId2','testAccountId','6rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYBg=','82546A29-6BC8-4311-BCFC-9CDBF33F1A49','"100.64.114.32"','test-host-2','linux','Linux','','unknown','Ubuntu','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'test-peer-2','test-peer-2','2023-03-02 09:21:02.189035775+01:00',1,0,0,'testAdminId','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -0,0 +1,23 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`protocol` text,`bidirectional` numeric,`sources` text,`destinations` text,`source_resource` text,`destination_resource` text,`ports` text,`port_ranges` text,`authorized_groups` text,`authorized_user` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules_g` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO policies VALUES('testPolicyId','testAccountId','testPolicy','test policy description',1,NULL);
INSERT INTO policy_rules VALUES('testRuleId','testPolicyId','testRule','test rule',1,'accept','all',1,'["testGroupId"]','["testGroupId"]',NULL,NULL,NULL,NULL,NULL,'');

View File

@@ -0,0 +1,23 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,`skip_auto_apply` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('peerGroupId','testAccountId','peerGroupName','api','["testPeerId"]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO routes VALUES('testRouteId','testAccountId','"10.0.0.0/24"',NULL,0,'testNet','Test Network Route','testPeerId',NULL,1,1,100,1,'["testGroupId"]',NULL,0);
INSERT INTO routes VALUES('testDomainRouteId','testAccountId','"0.0.0.0/0"','["example.com"]',0,'testDomainNet','Test Domain Route','','["peerGroupId"]',3,1,200,1,'["testGroupId"]',NULL,0);

View File

@@ -0,0 +1,24 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime DEFAULT NULL,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'testServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'testServiceAdmin','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('deletableServiceUserId','testAccountId','user',1,0,'deletableServiceUser','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO personal_access_tokens VALUES('testTokenId','testUserId','testToken','hashedTokenValue123','2325-10-02 16:01:38.000000000+00:00','testUserId','2024-10-02 16:01:38.000000000+00:00',NULL);
INSERT INTO personal_access_tokens VALUES('serviceTokenId','testServiceUserId','serviceToken','hashedServiceTokenValue123','2325-10-02 16:01:38.000000000+00:00','testAdminId','2024-10-02 16:01:38.000000000+00:00',NULL);

View File

@@ -114,13 +114,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
domainManager.SetClusterCapabilities(serviceProxyController)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
@@ -128,14 +127,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
GetPATInfoFunc: authManager.GetPATInfo,
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
groupsManager := groups.NewManager(store, permissionsManager, am)
routersManager := routers.NewManager(store, permissionsManager, am)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am)
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -167,6 +166,111 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m
}
}
// PeerShouldReceiveAnyUpdate waits for a peer update message and returns it.
// Fails the test if no update is received within timeout.
func PeerShouldReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) *network_map.UpdateMessage {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
return msg
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
return nil
}
}
// PeerShouldNotReceiveAnyUpdate verifies no peer update message is received.
func PeerShouldNotReceiveAnyUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
peerShouldNotReceiveUpdate(t, updateMessage)
}
// BuildApiBlackBoxWithDBStateAndPeerChannel creates the API handler and returns
// the peer update channel directly so tests can verify updates inline.
func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile string) (http.Handler, account.Manager, <-chan *network_map.UpdateMessage) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
geoMock := &geolocation.Mock{}
validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
GetPATInfoFunc: authManager.GetPATInfo,
}
groupsManager := groups.NewManager(store, permissionsManager, am)
routersManager := routers.NewManager(store, permissionsManager, am)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, am, serviceManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, am)
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, updMsg
}
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
userAuth := auth.UserAuth{}

View File

@@ -0,0 +1,222 @@
package testing_tools
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// GetDB extracts the *gorm.DB from a store.Store (must be *SqlStore).
func GetDB(t *testing.T, s store.Store) *gorm.DB {
t.Helper()
sqlStore, ok := s.(*store.SqlStore)
require.True(t, ok, "Store is not a *SqlStore, cannot get gorm.DB")
return sqlStore.GetDB()
}
// VerifyGroupInDB reads a group directly from the DB and returns it.
func VerifyGroupInDB(t *testing.T, db *gorm.DB, groupID string) *types.Group {
t.Helper()
var group types.Group
err := db.Where("id = ? AND account_id = ?", groupID, TestAccountId).First(&group).Error
require.NoError(t, err, "Expected group %s to exist in DB", groupID)
return &group
}
// VerifyGroupNotInDB verifies that a group does not exist in the DB.
func VerifyGroupNotInDB(t *testing.T, db *gorm.DB, groupID string) {
t.Helper()
var count int64
db.Model(&types.Group{}).Where("id = ? AND account_id = ?", groupID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected group %s to NOT exist in DB", groupID)
}
// VerifyPolicyInDB reads a policy directly from the DB and returns it.
func VerifyPolicyInDB(t *testing.T, db *gorm.DB, policyID string) *types.Policy {
t.Helper()
var policy types.Policy
err := db.Preload("Rules").Where("id = ? AND account_id = ?", policyID, TestAccountId).First(&policy).Error
require.NoError(t, err, "Expected policy %s to exist in DB", policyID)
return &policy
}
// VerifyPolicyNotInDB verifies that a policy does not exist in the DB.
func VerifyPolicyNotInDB(t *testing.T, db *gorm.DB, policyID string) {
t.Helper()
var count int64
db.Model(&types.Policy{}).Where("id = ? AND account_id = ?", policyID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected policy %s to NOT exist in DB", policyID)
}
// VerifyRouteInDB reads a route directly from the DB and returns it.
func VerifyRouteInDB(t *testing.T, db *gorm.DB, routeID route.ID) *route.Route {
t.Helper()
var r route.Route
err := db.Where("id = ? AND account_id = ?", routeID, TestAccountId).First(&r).Error
require.NoError(t, err, "Expected route %s to exist in DB", routeID)
return &r
}
// VerifyRouteNotInDB verifies that a route does not exist in the DB.
func VerifyRouteNotInDB(t *testing.T, db *gorm.DB, routeID route.ID) {
t.Helper()
var count int64
db.Model(&route.Route{}).Where("id = ? AND account_id = ?", routeID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected route %s to NOT exist in DB", routeID)
}
// VerifyNSGroupInDB reads a nameserver group directly from the DB and returns it.
func VerifyNSGroupInDB(t *testing.T, db *gorm.DB, nsGroupID string) *nbdns.NameServerGroup {
t.Helper()
var nsGroup nbdns.NameServerGroup
err := db.Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).First(&nsGroup).Error
require.NoError(t, err, "Expected NS group %s to exist in DB", nsGroupID)
return &nsGroup
}
// VerifyNSGroupNotInDB verifies that a nameserver group does not exist in the DB.
func VerifyNSGroupNotInDB(t *testing.T, db *gorm.DB, nsGroupID string) {
t.Helper()
var count int64
db.Model(&nbdns.NameServerGroup{}).Where("id = ? AND account_id = ?", nsGroupID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected NS group %s to NOT exist in DB", nsGroupID)
}
// VerifyPeerInDB reads a peer directly from the DB and returns it.
func VerifyPeerInDB(t *testing.T, db *gorm.DB, peerID string) *nbpeer.Peer {
t.Helper()
var peer nbpeer.Peer
err := db.Where("id = ? AND account_id = ?", peerID, TestAccountId).First(&peer).Error
require.NoError(t, err, "Expected peer %s to exist in DB", peerID)
return &peer
}
// VerifyPeerNotInDB verifies that a peer does not exist in the DB.
func VerifyPeerNotInDB(t *testing.T, db *gorm.DB, peerID string) {
t.Helper()
var count int64
db.Model(&nbpeer.Peer{}).Where("id = ? AND account_id = ?", peerID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected peer %s to NOT exist in DB", peerID)
}
// VerifySetupKeyInDB reads a setup key directly from the DB and returns it.
func VerifySetupKeyInDB(t *testing.T, db *gorm.DB, keyID string) *types.SetupKey {
t.Helper()
var key types.SetupKey
err := db.Where("id = ? AND account_id = ?", keyID, TestAccountId).First(&key).Error
require.NoError(t, err, "Expected setup key %s to exist in DB", keyID)
return &key
}
// VerifySetupKeyNotInDB verifies that a setup key does not exist in the DB.
func VerifySetupKeyNotInDB(t *testing.T, db *gorm.DB, keyID string) {
t.Helper()
var count int64
db.Model(&types.SetupKey{}).Where("id = ? AND account_id = ?", keyID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected setup key %s to NOT exist in DB", keyID)
}
// VerifyUserInDB reads a user directly from the DB and returns it.
func VerifyUserInDB(t *testing.T, db *gorm.DB, userID string) *types.User {
t.Helper()
var user types.User
err := db.Where("id = ? AND account_id = ?", userID, TestAccountId).First(&user).Error
require.NoError(t, err, "Expected user %s to exist in DB", userID)
return &user
}
// VerifyUserNotInDB verifies that a user does not exist in the DB.
func VerifyUserNotInDB(t *testing.T, db *gorm.DB, userID string) {
t.Helper()
var count int64
db.Model(&types.User{}).Where("id = ? AND account_id = ?", userID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected user %s to NOT exist in DB", userID)
}
// VerifyPATInDB reads a PAT directly from the DB and returns it.
func VerifyPATInDB(t *testing.T, db *gorm.DB, tokenID string) *types.PersonalAccessToken {
t.Helper()
var pat types.PersonalAccessToken
err := db.Where("id = ?", tokenID).First(&pat).Error
require.NoError(t, err, "Expected PAT %s to exist in DB", tokenID)
return &pat
}
// VerifyPATNotInDB verifies that a PAT does not exist in the DB.
func VerifyPATNotInDB(t *testing.T, db *gorm.DB, tokenID string) {
t.Helper()
var count int64
db.Model(&types.PersonalAccessToken{}).Where("id = ?", tokenID).Count(&count)
assert.Equal(t, int64(0), count, "Expected PAT %s to NOT exist in DB", tokenID)
}
// VerifyAccountSettings reads the account and returns its settings from the DB.
func VerifyAccountSettings(t *testing.T, db *gorm.DB) *types.Account {
t.Helper()
var account types.Account
err := db.Where("id = ?", TestAccountId).First(&account).Error
require.NoError(t, err, "Expected account %s to exist in DB", TestAccountId)
return &account
}
// VerifyNetworkInDB reads a network directly from the store and returns it.
func VerifyNetworkInDB(t *testing.T, db *gorm.DB, networkID string) *networkTypes.Network {
t.Helper()
var network networkTypes.Network
err := db.Where("id = ? AND account_id = ?", networkID, TestAccountId).First(&network).Error
require.NoError(t, err, "Expected network %s to exist in DB", networkID)
return &network
}
// VerifyNetworkNotInDB verifies that a network does not exist in the DB.
func VerifyNetworkNotInDB(t *testing.T, db *gorm.DB, networkID string) {
t.Helper()
var count int64
db.Model(&networkTypes.Network{}).Where("id = ? AND account_id = ?", networkID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected network %s to NOT exist in DB", networkID)
}
// VerifyNetworkResourceInDB reads a network resource directly from the DB and returns it.
func VerifyNetworkResourceInDB(t *testing.T, db *gorm.DB, resourceID string) *resourceTypes.NetworkResource {
t.Helper()
var resource resourceTypes.NetworkResource
err := db.Where("id = ? AND account_id = ?", resourceID, TestAccountId).First(&resource).Error
require.NoError(t, err, "Expected network resource %s to exist in DB", resourceID)
return &resource
}
// VerifyNetworkResourceNotInDB verifies that a network resource does not exist in the DB.
func VerifyNetworkResourceNotInDB(t *testing.T, db *gorm.DB, resourceID string) {
t.Helper()
var count int64
db.Model(&resourceTypes.NetworkResource{}).Where("id = ? AND account_id = ?", resourceID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected network resource %s to NOT exist in DB", resourceID)
}
// VerifyNetworkRouterInDB reads a network router directly from the DB and returns it.
func VerifyNetworkRouterInDB(t *testing.T, db *gorm.DB, routerID string) *routerTypes.NetworkRouter {
t.Helper()
var router routerTypes.NetworkRouter
err := db.Where("id = ? AND account_id = ?", routerID, TestAccountId).First(&router).Error
require.NoError(t, err, "Expected network router %s to exist in DB", routerID)
return &router
}
// VerifyNetworkRouterNotInDB verifies that a network router does not exist in the DB.
func VerifyNetworkRouterNotInDB(t *testing.T, db *gorm.DB, routerID string) {
t.Helper()
var count int64
db.Model(&routerTypes.NetworkRouter{}).Where("id = ? AND account_id = ?", routerID, TestAccountId).Count(&count)
assert.Equal(t, int64(0), count, "Expected network router %s to NOT exist in DB", routerID)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct {
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
// StaticConnectors are additional connectors to seed during initialization
StaticConnectors []dex.Connector
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
}
// Add owner user if provided
@@ -193,6 +197,9 @@ type OAuthConfigProvider interface {
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
// or nil if direct key fetching is not supported (falls back to HTTP).
GetKeyFetcher() nbjwt.KeyFetcher
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
return m.config.CLIRedirectURIs
}
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
return m.provider.GetJWKS
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()

View File

@@ -0,0 +1,235 @@
// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0
// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later.
// It includes functions to seed connectors and migrate existing users to use these connectors.
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// Server is the dependency interface that migration functions use to access
// the main data store and the activity event store.
type Server interface {
Store() Store
EventStore() EventStore // may return nil
}
const idpSeedInfoKey = "IDP_SEED_INFO"
const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN"
func isDryRun() bool {
return os.Getenv(dryRunEnvKey) == "true"
}
var ErrNoSeedInfo = errors.New("no seed info found in environment")
// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it,
// and JSON-unmarshals it into a dex.Connector. Returns nil if not set.
func SeedConnectorFromEnv() (*dex.Connector, error) {
val, ok := os.LookupEnv(idpSeedInfoKey)
if !ok || val == "" {
return nil, ErrNoSeedInfo
}
decoded, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
var conn dex.Connector
if err := json.Unmarshal(decoded, &conn); err != nil {
return nil, fmt.Errorf("json unmarshal: %w", err)
}
return &conn, nil
}
// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and
// the activity store, if present) so that it encodes the given connector ID,
// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true
// to log what would happen without writing any changes.
func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error {
ctx := context.Background()
if isDryRun() {
log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written")
}
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Reconciliation pass: fix activity store for users already migrated in main DB
// but whose activity references may still use old IDs (from a previous partial failure).
if s.EventStore() != nil && !isDryRun() {
if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil {
return err
}
}
var migratedCount, skippedCount int
for _, user := range users {
_, _, decErr := dex.DecodeDexUserID(user.Id)
if decErr == nil {
skippedCount++
continue
}
newUserID := dex.EncodeDexUserID(user.Id, conn.ID)
if isDryRun() {
log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID)
migratedCount++
continue
}
if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil {
return err
}
migratedCount++
}
if isDryRun() {
log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount)
} else {
log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount)
}
return nil
}
// reconcileActivityStore updates activity store references for users already migrated
// in the main DB whose activity entries may still use old IDs from a previous partial failure.
func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error {
for _, user := range users {
originalID, _, err := dex.DecodeDexUserID(user.Id)
if err != nil {
// skip users that aren't migrated, they will be handled in the main migration loop
continue
}
if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil {
return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err)
}
}
return nil
}
// migrateUser updates a single user's ID in both the main store and the activity store.
func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error {
if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil {
return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err)
}
if s.EventStore() == nil {
return nil
}
if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil {
return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err)
}
return nil
}
// PopulateUserInfo fetches user email and name from the external IDP and updates
// the store for users that are missing this information.
func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error {
ctx := context.Background()
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Build a map of IDP user ID -> UserData from the external IDP
allAccounts, err := idpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to fetch accounts from IDP: %w", err)
}
idpUsers := make(map[string]*idp.UserData)
for _, accountUsers := range allAccounts {
for _, userData := range accountUsers {
idpUsers[userData.ID] = userData
}
}
log.Infof("fetched %d users from IDP", len(idpUsers))
var updatedCount, skippedCount, notFoundCount int
for _, user := range users {
if user.IsServiceUser {
skippedCount++
continue
}
if user.Email != "" && user.Name != "" {
skippedCount++
continue
}
// The user ID in the store may be the original IDP ID or a Dex-encoded ID.
// Try to decode the Dex format first to get the original IDP ID.
lookupID := user.Id
if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil {
lookupID = originalID
}
idpUser, found := idpUsers[lookupID]
if !found {
notFoundCount++
log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID)
continue
}
email := user.Email
name := user.Name
if email == "" && idpUser.Email != "" {
email = idpUser.Email
}
if name == "" && idpUser.Name != "" {
name = idpUser.Name
}
if email == user.Email && name == user.Name {
skippedCount++
continue
}
if dryRun {
log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
continue
}
if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil {
return fmt.Errorf("failed to update user info for %s: %w", user.Id, err)
}
log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
}
if dryRun {
log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
} else {
log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
}
return nil
}

View File

@@ -0,0 +1,828 @@
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// testStore is a hand-written mock for MigrationStore.
type testStore struct {
listUsersFunc func(ctx context.Context) ([]*types.User, error)
updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error
updateUserInfoFunc func(ctx context.Context, userID, email, name string) error
checkSchemaFunc func(checks []SchemaCheck) []SchemaError
updateCalls []updateUserIDCall
updateInfoCalls []updateUserInfoCall
}
type updateUserIDCall struct {
AccountID string
OldUserID string
NewUserID string
}
type updateUserInfoCall struct {
UserID string
Email string
Name string
}
func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) {
return s.listUsersFunc(ctx)
}
func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID})
return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID)
}
func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name})
if s.updateUserInfoFunc != nil {
return s.updateUserInfoFunc(ctx, userID, email, name)
}
return nil
}
func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError {
if s.checkSchemaFunc != nil {
return s.checkSchemaFunc(checks)
}
return nil
}
type testServer struct {
store Store
eventStore EventStore
}
func (s *testServer) Store() Store { return s.store }
func (s *testServer) EventStore() EventStore { return s.eventStore }
func TestSeedConnectorFromEnv(t *testing.T) {
t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) {
os.Unsetenv(idpSeedInfoKey)
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "")
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns error on invalid base64", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "not-valid-base64!!!")
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "base64 decode")
})
t.Run("returns error on invalid JSON", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("not json"))
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "json unmarshal")
})
t.Run("successfully decodes valid connector", func(t *testing.T) {
expected := dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-provider",
Config: map[string]any{
"issuer": "https://example.com",
"clientID": "my-client-id",
"clientSecret": "my-secret",
},
}
data, err := json.Marshal(expected)
require.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.Equal(t, expected.Type, conn.Type)
assert.Equal(t, expected.Name, conn.Name)
assert.Equal(t, expected.ID, conn.ID)
assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"])
})
}
func TestMigrateUsersToStaticConnectors(t *testing.T) {
connector := &dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-connector",
}
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("migrates single user with correct encoded ID", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
expectedNewID := dex.EncodeDexUserID("user-1", "test-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
require.Len(t, ms.updateCalls, 1)
assert.Equal(t, "account-1", ms.updateCalls[0].AccountID)
assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID)
assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID)
})
t.Run("migrates multiple users", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 3)
})
t.Run("returns error when UpdateUserID fails", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
callCount++
if callCount == 2 {
return fmt.Errorf("update failed")
}
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user ID for user user-2")
})
t.Run("stops on first UpdateUserID error", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return fmt.Errorf("update failed")
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Len(t, ms.updateCalls, 1) // stopped after first error
})
t.Run("skips already migrated users", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
// Only user-2 and user-3 should be migrated
assert.Len(t, ms.updateCalls, 2)
assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID)
assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID)
})
t.Run("dry run does not call UpdateUserID", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("dry run skips already migrated users", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("dry run disabled by default", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run
})
}
func TestPopulateUserInfo(t *testing.T) {
noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("returns error when GetAllAccounts fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return nil, fmt.Errorf("idp error")
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch accounts from IDP")
})
t.Run("updates user with missing email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "user1@example.com", Name: "User One"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User One", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing email when name exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing name when email exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name)
})
t.Run("skips users that already have both email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips service users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips users not found in IDP", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) {
dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User", ms.updateInfoCalls[0].Name)
})
t.Run("handles multiple users across multiple accounts", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"},
{Id: "user-3", AccountID: "acc-2", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "User 1"},
{ID: "user-2", Email: "u2@example.com", Name: "User 2"},
},
"acc-2": {
{ID: "user-3", Email: "u3@example.com", Name: "User 3"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 2)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID)
assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email)
})
t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user info for user-1")
})
t.Run("stops on first UpdateUserInfo error", func(t *testing.T) {
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
callCount++
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Equal(t, 1, callCount)
})
t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
t.Fatal("UpdateUserInfo should not be called in dry-run mode")
return nil
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, true)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips user when IDP has empty email and name too", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "", Name: ""}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
}
func TestSchemaError_String(t *testing.T) {
t.Run("missing table", func(t *testing.T) {
e := SchemaError{Table: "jobs"}
assert.Equal(t, `table "jobs" is missing`, e.String())
})
t.Run("missing column", func(t *testing.T) {
e := SchemaError{Table: "users", Column: "email"}
assert.Equal(t, `column "email" on table "users" is missing`, e.String())
})
}
func TestRequiredSchema(t *testing.T) {
// Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo.
expectedTables := []string{
"users",
"personal_access_tokens",
"peers",
"accounts",
"user_invites",
"proxy_access_tokens",
"jobs",
}
schemaTableNames := make([]string, len(RequiredSchema))
for i, s := range RequiredSchema {
schemaTableNames[i] = s.Table
}
for _, expected := range expectedTables {
assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected)
}
}
func TestCheckSchema_MockStore(t *testing.T) {
t.Run("returns nil when all schema exists", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return nil
},
}
errs := ms.CheckSchema(RequiredSchema)
assert.Empty(t, errs)
})
t.Run("returns errors for missing tables", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "jobs"},
{Table: "proxy_access_tokens"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "jobs", errs[0].Table)
assert.Equal(t, "", errs[0].Column)
assert.Equal(t, "proxy_access_tokens", errs[1].Table)
})
t.Run("returns errors for missing columns", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "users", Column: "email"},
{Table: "users", Column: "name"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "users", errs[0].Table)
assert.Equal(t, "email", errs[0].Column)
})
}

View File

@@ -0,0 +1,82 @@
package migration
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/types"
)
// SchemaCheck represents a table and the columns required on it.
type SchemaCheck struct {
Table string
Columns []string
}
// RequiredSchema lists all tables and columns that the migration tool needs.
// If any are missing, the user must upgrade their management server first so
// that the automatic GORM migrations create them.
var RequiredSchema = []SchemaCheck{
{Table: "users", Columns: []string{"id", "email", "name", "account_id"}},
{Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}},
{Table: "peers", Columns: []string{"user_id"}},
{Table: "accounts", Columns: []string{"created_by"}},
{Table: "user_invites", Columns: []string{"created_by"}},
{Table: "proxy_access_tokens", Columns: []string{"created_by"}},
{Table: "jobs", Columns: []string{"triggered_by"}},
}
// SchemaError describes a single missing table or column.
type SchemaError struct {
Table string
Column string // empty when the whole table is missing
}
func (e SchemaError) String() string {
if e.Column == "" {
return fmt.Sprintf("table %q is missing", e.Table)
}
return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table)
}
// Store defines the data store operations required for IdP user migration.
// This interface is separate from the main store.Store interface because these methods
// are only used during one-time migration and should be removed once migration tooling
// is no longer needed.
//
// The SQL store implementations (SqlStore) already have these methods on their concrete
// types, so they satisfy this interface via Go's structural typing with zero code changes.
type Store interface {
// ListUsers returns all users across all accounts.
ListUsers(ctx context.Context) ([]*types.User, error)
// UpdateUserID atomically updates a user's ID and all foreign key references
// across the database (peers, groups, policies, PATs, etc.).
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
// UpdateUserInfo updates a user's email and name in the store.
UpdateUserInfo(ctx context.Context, userID, email, name string) error
// CheckSchema verifies that all tables and columns required by the migration
// exist in the database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
}
// RequiredEventSchema lists all tables and columns that the migration tool needs
// in the activity/event store.
var RequiredEventSchema = []SchemaCheck{
{Table: "events", Columns: []string{"initiator_id", "target_id"}},
{Table: "deleted_users", Columns: []string{"id"}},
}
// EventStore defines the activity event store operations required for migration.
// Like Store, this is a temporary interface for migration tooling only.
type EventStore interface {
// CheckSchema verifies that all tables and columns required by the migration
// exist in the event database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
// UpdateUserID updates all event references (initiator_id, target_id) and
// deleted_users records to use the new user ID format.
UpdateUserID(ctx context.Context, oldUserID, newUserID string) error
}

View File

@@ -64,10 +64,19 @@ type Manager interface {
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
type instanceStore interface {
GetAccountsCounter(ctx context.Context) (int64, error)
}
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
// DefaultManager is the default implementation of Manager.
type DefaultManager struct {
store store.Store
embeddedIdpManager *idp.EmbeddedIdPManager
store instanceStore
embeddedIdpManager embeddedIdP
setupRequired bool
setupMu sync.RWMutex
@@ -82,18 +91,18 @@ type DefaultManager struct {
// NewManager creates a new instance manager.
// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults.
func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) {
embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager)
embeddedIdp, ok := idpManager.(*idp.EmbeddedIdPManager)
m := &DefaultManager{
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
store: store,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
if ok && embeddedIdp != nil {
m.embeddedIdpManager = embeddedIdp
err := m.loadSetupRequired(ctx)
if err != nil {
return nil, err
@@ -143,36 +152,61 @@ func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
m.setupMu.RLock()
setupRequired := m.setupRequired
m.setupMu.RUnlock()
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if !setupRequired {
m.setupMu.Lock()
defer m.setupMu.Unlock()
if !m.setupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
if err := m.checkSetupRequiredFromDB(ctx); err != nil {
var sErr *status.Error
if errors.As(err, &sErr) && sErr.Type() == status.PreconditionFailed {
m.setupRequired = false
}
return nil, err
}
userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
if err != nil {
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
m.setupMu.Lock()
m.setupRequired = false
m.setupMu.Unlock()
log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email)
return userData, nil
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return fmt.Errorf("failed to check accounts: %w", err)
}
if numAccounts > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to check IdP users: %w", err)
}
if len(users) > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
return nil
}
func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if email == "" {
return status.Errorf(status.InvalidArgument, "email is required")
@@ -189,6 +223,9 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if len(password) < 8 {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
if len(password) > 72 {
return status.Errorf(status.InvalidArgument, "password must be at most 72 characters")
}
return nil
}

View File

@@ -3,7 +3,12 @@ package instance
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -11,173 +16,215 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
)
// mockStore implements a minimal store.Store for testing
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
getAllAccountsErr error
}
func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
}
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.getAllAccountsErr != nil {
return nil, m.getAllAccountsErr
}
return m.users, nil
}
type mockStore struct {
accountsCount int64
err error
}
func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) {
func (m *mockStore) GetAccountsCounter(_ context.Context) (int64, error) {
if m.err != nil {
return 0, m.err
}
return m.accountsCount, nil
}
// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing
type mockEmbeddedIdPManager struct {
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
func newTestManager(idpMock *mockIdP, storeMock *mockStore) *DefaultManager {
return &DefaultManager{
store: storeMock,
embeddedIdpManager: idpMock,
setupRequired: true,
httpClient: &http.Client{Timeout: httpTimeout},
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
// testManager is a test implementation that accepts our mock types
type testManager struct {
store *mockStore
embeddedIdpManager *mockEmbeddedIdPManager
}
func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.embeddedIdpManager == nil {
return false, nil
}
count, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return false, err
}
return count == 0, nil
}
func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil, // No embedded IDP
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when embedded IDP is disabled")
}
func TestIsSetupRequired_NoAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts exist")
}
func TestIsSetupRequired_AccountsExist(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 1},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when accounts exist")
}
func TestIsSetupRequired_MultipleAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 5},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when multiple accounts exist")
}
func TestIsSetupRequired_StoreError(t *testing.T) {
manager := &testManager{
store: &mockStore{err: errors.New("database error")},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
_, err := manager.IsSetupRequired(context.Background())
assert.Error(t, err, "should return error when store fails")
}
func TestCreateOwnerUser_Success(t *testing.T) {
expectedEmail := "admin@example.com"
expectedName := "Admin User"
expectedPassword := "securepassword123"
idpMock := &mockIdP{}
mgr := newTestManager(idpMock, &mockStore{})
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, expectedEmail, email)
assert.Equal(t, expectedPassword, password)
assert.Equal(t, expectedName, name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
},
}
userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName)
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "created-user-id", userData.ID)
assert.Equal(t, expectedEmail, userData.Email)
assert.Equal(t, expectedName, userData.Name)
assert.Equal(t, "admin@example.com", userData.Email)
_, err = mgr.CreateOwnerUser(context.Background(), "admin2@example.com", "password123", "Admin2")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
mgr.setupRequired = false
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil,
}
mgr := &DefaultManager{setupRequired: true}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when embedded IDP is disabled")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "embedded IDP is not enabled")
}
func TestCreateOwnerUser_IdPError(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user already exists")
},
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, _, _, _ string) (*idp.UserData, error) {
return nil, errors.New("provider error")
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when IDP fails")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "provider error")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after IdP error")
}
func TestCreateOwnerUser_TransientDBError_DoesNotBlockSetup(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{err: errors.New("connection refused")})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection refused")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient DB error")
mgr.store = &mockStore{}
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_TransientIdPError_DoesNotBlockSetup(t *testing.T) {
idpMock := &mockIdP{getAllAccountsErr: errors.New("connection reset")}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection reset")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient IdP error")
idpMock.getAllAccountsErr = nil
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_DBCheckBlocksConcurrent(t *testing.T) {
idpMock := &mockIdP{
users: map[string][]*idp.UserData{
"acc1": {{ID: "existing-user"}},
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_DBCheckBlocksWhenAccountsExist(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{accountsCount: 1})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_ConcurrentRequests(t *testing.T) {
var idpCallCount atomic.Int32
var successCount atomic.Int32
var failCount atomic.Int32
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, email, _, _ string) (*idp.UserData, error) {
idpCallCount.Add(1)
time.Sleep(50 * time.Millisecond)
return &idp.UserData{ID: "user-1", Email: email, Name: "Owner"}, nil
},
}
mgr := newTestManager(idpMock, &mockStore{})
var wg sync.WaitGroup
for i := range 10 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := mgr.CreateOwnerUser(
context.Background(),
fmt.Sprintf("owner%d@example.com", idx),
"password1234",
fmt.Sprintf("Owner%d", idx),
)
if err != nil {
failCount.Add(1)
} else {
successCount.Add(1)
}
}(i)
}
wg.Wait()
assert.Equal(t, int32(1), successCount.Load(), "exactly one concurrent setup request should succeed")
assert.Equal(t, int32(9), failCount.Load(), "remaining concurrent requests should fail")
assert.Equal(t, int32(1), idpCallCount.Load(), "IdP CreateUser should be called exactly once")
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
mgr := &DefaultManager{}
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required)
mgr.setupMu.Lock()
mgr.setupRequired = false
mgr.setupMu.Unlock()
required, err = mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{
setupRequired: true,
}
manager := &DefaultManager{setupRequired: true}
tests := []struct {
name string
@@ -188,11 +235,10 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg string
}{
{
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
expectError: false,
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
},
{
name: "empty email",
@@ -235,11 +281,24 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg: "password must be at least 8 characters",
},
{
name: "password exactly 8 characters",
name: "password exactly 8 characters",
email: "admin@example.com",
password: "12345678",
userName: "Admin User",
},
{
name: "password exactly 72 characters",
email: "admin@example.com",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiii",
userName: "Admin User",
},
{
name: "password too long",
email: "admin@example.com",
password: "12345678",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiiij",
userName: "Admin User",
expectError: false,
expectError: true,
errorMsg: "password must be at most 72 characters",
},
}
@@ -255,14 +314,3 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
})
}
}
func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
manager := &DefaultManager{
setupRequired: false,
embeddedIdpManager: &idp.EmbeddedIdPManager{},
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}

View File

@@ -46,7 +46,7 @@ type MockAccountManager struct {
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName)
return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
}

View File

@@ -859,7 +859,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if !temporary {
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
@@ -1480,9 +1482,11 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
}
}
return peerDeletedEvents, nil

View File

@@ -84,7 +84,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -2080,7 +2080,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
mode, listen_port, port_auto_assigned, source, source_peer, terminated
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
@@ -2097,6 +2098,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2112,6 +2114,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
&mode,
&s.ListenPort,
&s.PortAutoAssigned,
&source,
&sourcePeer,
&s.Terminated,
)
if err != nil {
return nil, err
@@ -2143,6 +2151,15 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sessionPublicKey.Valid {
s.SessionPublicKey = sessionPublicKey.String
}
if mode.Valid {
s.Mode = mode.String
}
if source.Valid {
s.Source = source.String
}
if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String
}
s.Targets = []*rpservice.Target{}
return &s, nil
@@ -5445,7 +5462,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
@@ -5463,7 +5480,7 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster,
result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Group("cluster_address").
Scan(&clusters)
@@ -5475,6 +5492,63 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster,
return clusters, nil
}
// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be
// considered active. Must be at least 2x the heartbeat interval (1 min).
const proxyActiveThreshold = 2 * time.Minute
var validCapabilityColumns = map[string]struct{}{
"supports_custom_ports": {},
"require_subdomain": {},
}
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy reported the capability.
func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports")
}
// GetClusterRequireSubdomain returns whether any active proxy in the cluster
// requires a subdomain. Returns nil when no proxy reported the capability.
func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
}
// getClusterCapability returns an aggregated boolean capability for the given
// cluster. It checks active (connected, recently seen) proxies and returns:
// - *true if any proxy in the cluster has the capability set to true,
// - *false if at least one proxy reported but none set it to true,
// - nil if no proxy reported the capability at all.
func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool {
if _, ok := validCapabilityColumns[column]; !ok {
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
return nil
}
var result struct {
HasCapability bool
AnyTrue bool
}
err := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
Where("cluster_address = ? AND status = ? AND last_seen > ?",
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
Scan(&result).Error
if err != nil {
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
return nil
}
if !result.HasCapability {
return nil
}
return &result.AnyTrue
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
@@ -5494,3 +5568,61 @@ func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration t
return nil
}
// GetRoutingPeerNetworks returns the distinct network names where the peer is assigned as a routing peer
// in an enabled network router, either directly or via peer groups.
func (s *SqlStore) GetRoutingPeerNetworks(_ context.Context, accountID, peerID string) ([]string, error) {
var routers []*routerTypes.NetworkRouter
if err := s.db.Select("peer, peer_groups, network_id").Where("account_id = ? AND enabled = true", accountID).Find(&routers).Error; err != nil {
return nil, status.Errorf(status.Internal, "failed to get enabled routers: %v", err)
}
if len(routers) == 0 {
return nil, nil
}
var groupPeers []types.GroupPeer
if err := s.db.Select("group_id").Where("account_id = ? AND peer_id = ?", accountID, peerID).Find(&groupPeers).Error; err != nil {
return nil, status.Errorf(status.Internal, "failed to get peer group memberships: %v", err)
}
groupSet := make(map[string]struct{}, len(groupPeers))
for _, gp := range groupPeers {
groupSet[gp.GroupID] = struct{}{}
}
networkIDs := make(map[string]struct{})
for _, r := range routers {
if r.Peer == peerID {
networkIDs[r.NetworkID] = struct{}{}
} else if r.Peer == "" {
for _, pg := range r.PeerGroups {
if _, ok := groupSet[pg]; ok {
networkIDs[r.NetworkID] = struct{}{}
break
}
}
}
}
if len(networkIDs) == 0 {
return nil, nil
}
ids := make([]string, 0, len(networkIDs))
for id := range networkIDs {
ids = append(ids, id)
}
var networks []*networkTypes.Network
if err := s.db.Select("name").Where("account_id = ? AND id IN ?", accountID, ids).Find(&networks).Error; err != nil {
return nil, status.Errorf(status.Internal, "failed to get networks: %v", err)
}
names := make([]string, 0, len(networks))
for _, n := range networks {
names = append(names, n.Name)
}
return names, nil
}

Some files were not shown because too many files have changed in this diff Show More