Merge remote-tracking branch 'origin/main' into iptables-mangle-dnat-guard

This commit is contained in:
Viktor Liu
2026-04-08 07:50:39 +02:00
166 changed files with 18492 additions and 2122 deletions

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -61,8 +61,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -154,6 +154,26 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -166,6 +186,10 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
pi := PeerInfo{
p.IP,
p.FQDN,
p.ConnStatus.String(),
int(p.ConnStatus),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi

View File

@@ -2,11 +2,20 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
FQDN string
ConnStatus string // Todo replace to enum
ConnStatus int
Routes PeerRoutes
}

View File

@@ -14,7 +14,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
@@ -200,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %w", err)
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -211,26 +213,31 @@ func exposeFn(cmd *cobra.Command, args []string) error {
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
case expose.ProtocolHTTPS:
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case "tcp":
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case "udp":
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case "tls":
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
return 0, fmt.Errorf("unhandled protocol type: %d", p)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)

View File

@@ -41,7 +41,7 @@ func init() {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")

View File

@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed")
return nil
},
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning {
cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil {

View File

@@ -0,0 +1,201 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -0,0 +1,523 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -13,6 +15,22 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -33,14 +33,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -375,6 +375,32 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()

45
client/embed/expose.go Normal file
View File

@@ -0,0 +1,45 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -286,6 +286,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"

View File

@@ -36,6 +36,7 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -43,6 +44,7 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -169,6 +169,14 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error

View File

@@ -346,6 +346,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"

View File

@@ -36,6 +36,7 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -140,6 +140,17 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
manager.SetUDPPacketHook(
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,9 +18,7 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
udpHook func([]byte) bool
drop bool
}
// ID returns the rule id

View File

@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: false,

View File

@@ -15,14 +15,17 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
}
// FilteredDevice to override Read or Write of packets

View File

@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
}
// FilterInbound mocks base method.
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -1,87 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
protoFlow, err := client.GetPKCEAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
protoFlow, err := client.GetDeviceAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err

View File

@@ -44,6 +44,10 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -610,12 +617,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -634,12 +635,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -0,0 +1,73 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -0,0 +1,32 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
cw.response.Len(), meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -85,6 +85,16 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// SetFirewall mock implementation of SetFirewall from Server interface
func (m *MockServer) SetFirewall(Firewall) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
var srcIP net.IP
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
srcIP = ipv4.(*layers.IPv4).SrcIP
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
srcIP = ipv6.(*layers.IPv6).SrcIP
}
var srcPort int
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
srcPort = int(udp.(*layers.UDP).SrcPort)
}
if srcIP == nil {
return nil
}
return &net.UDPAddr{IP: srcIP, Port: srcPort}
}

View File

@@ -57,6 +57,8 @@ type Server interface {
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
SetFirewall(Firewall)
}
type nsGroupsByDomain struct {
@@ -104,6 +106,7 @@ type DefaultServer struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
routeMatch func(netip.Addr) bool
probeMu sync.Mutex
probeCancel context.CancelFunc
@@ -149,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
@@ -229,6 +232,14 @@ func newDefaultServer(
return defaultServer
}
// SetRouteChecker sets the function used by upstream resolvers to determine
// whether an IP is routed through the tunnel.
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
s.mux.Lock()
defer s.mux.Unlock()
s.routeMatch = f
}
// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -364,6 +375,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP()
}
// SetFirewall sets the firewall used for DNS port DNAT rules.
// This must be called before Initialize when using the listener-based service,
// because the firewall is typically not available at construction time.
func (s *DefaultServer) SetFirewall(fw Firewall) {
if svc, ok := s.service.(*serviceViaListener); ok {
svc.listenerFlagLock.Lock()
svc.firewall = fw
svc.listenerFlagLock.Unlock()
}
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
@@ -385,8 +407,12 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
func (s *DefaultServer) disableDNS() error {
defer s.service.Stop()
func (s *DefaultServer) disableDNS() (retErr error) {
defer func() {
if err := s.service.Stop(); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
}
}()
if s.isUsingNoopHostManager() {
return nil
@@ -743,6 +769,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
return
}
handler.routeMatch = s.routeMatch
for _, ns := range originalNameservers {
if ns == config.ServerIP {
@@ -852,6 +879,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
if err != nil {
return nil, fmt.Errorf("create upstream resolver: %v", err)
}
handler.routeMatch = s.routeMatch
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
@@ -1036,6 +1064,7 @@ func (s *DefaultServer) addHostRootZone() {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.routeMatch = s.routeMatch
handler.upstreamServers = maps.Keys(hostDNSServers)
handler.deactivate = func(error) {}

View File

@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) Stop() error { return nil }
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}

View File

@@ -4,15 +4,25 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// Firewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type Firewall interface {
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop()
Stop() error
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int

View File

@@ -10,9 +10,13 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -31,25 +35,33 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall Firewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: fw,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
s.server.Addr = addr
s.tcpServer.Addr = addr
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
log.Debugf("starting dns on %s (UDP + TCP)", addr)
s.listenerIsRunning = true
go func() {
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
}
s.listenerFlagLock.Lock()
unexpected := s.listenerIsRunning
s.listenerIsRunning = false
s.listenerFlagLock.Unlock()
if unexpected {
if err := s.tcpServer.Shutdown(); err != nil {
log.Debugf("failed to shutdown DNS TCP server: %v", err)
}
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
func (s *serviceViaListener) Stop() {
func (s *serviceViaListener) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
s.listenerIsRunning = false
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
var merr *multierror.Error
if err := s.server.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
s.tcpDNATConfigured = false
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
if err := s.ebpfService.FreeDNSFwd(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}
// evalListenAddress figure out the listen address for the DNS server
// first check the 53 port availability on WG interface or lo, if not success
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
return false
}
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Bind both transports up front to avoid TOCTOU races.
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
t.Skip("cannot bind TCP on same port, skipping")
}
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.PacketConn = udpConn
svc.tcpServer.Listener = tcpLn
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer func() {
require.NoError(t, svc.Stop())
}()
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
@@ -10,6 +11,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
udpFilterHookID string
tcpDNS *tcpDNSServer
tcpHookSet bool
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
return &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return fmt.Errorf("filter dns traffice: %w", err)
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
}
s.listenerIsRunning = true
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() {
func (s *ServiceViaMemory) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
func (s *ServiceViaMemory) filterDNSTraffic() error {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
}
firstLayerDecoder := layers.LayerTypeIPv4
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
go s.dnsMux.ServeDNS(&writer, msg)
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
packet: packet,
device: dev.Device,
}
go s.dnsMux.ServeDNS(writer, msg)
return true
}
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
}

View File

@@ -0,0 +1,444 @@
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
dnsTCPReadTimeout = 5 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
closed bool
timerID uint64
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// InjectPacket ensures the stack is running and delivers a raw IP packet into
// the gvisor stack for TCP processing. Combining both operations under a single
// lock prevents a race where the idle timer could stop the stack between
// start and delivery.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return
}
if !t.running {
if err := t.startLocked(); err != nil {
log.Errorf("failed to start TCP DNS stack: %v", err)
return
}
t.running = true
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
}
t.resetTimerLocked()
ep := t.ep
if ep == nil || ep.dispatcher == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
// Stop tears down the gvisor stack and releases resources permanently.
// After Stop, InjectPacket becomes a no-op.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
t.closed = true
}
func (t *tcpDNSServer) startLocked() error {
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timerID++
id := t.timerID
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
// Only stop if this timer is still the active one.
// A racing InjectPacket may have replaced it.
if t.timerID != id {
return
}
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
}
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() { /* no async work */ }
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
data.Release()
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
continue
}
written++
}
return written, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err := w.conn.Write(buf); err != nil {
return 0, err
}
return len(data), nil
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
// Supports both IPv4 and IPv6.
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
if len(pkt) == 0 {
return netip.AddrPort{}
}
srcIP, transportOffset := srcIPFromPacket(pkt)
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
return netip.AddrPort{}
}
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
}
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
switch header.IPVersion(pkt) {
case 4:
return srcIPv4(pkt)
case 6:
return srcIPv6(pkt)
default:
return netip.Addr{}, 0
}
}
func srcIPv4(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv4MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv4(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, int(hdr.HeaderLength())
}
func srcIPv6(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv6MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv6(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, header.IPv6MinimumSize
}

View File

@@ -41,10 +41,61 @@ const (
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
const (
protoUDP = "udp"
protoTCP = "tcp"
)
type dnsProtocolKey struct{}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
func dnsProtocolFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if ctx == nil {
return
}
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
r.protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -70,6 +121,7 @@ type upstreamResolverBase struct {
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
@@ -137,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
ok, failures := u.tryUpstreamServers(w, r, logger)
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = contextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -152,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -167,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -177,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -202,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
@@ -219,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -427,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return dns.MinMsgSize
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
return rm, t, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than our read buffer.
// Note: the query could be sent out on an interface that is not ours,
// but higher MTU settings could break truncation handling.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
client.UDPSize = maxUDPPayload
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
@@ -452,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, protoUDP)
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// TODO: if the upstream's truncated UDP response already contains more
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
client.Net = "tcp"
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, t, nil
}
@@ -478,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
// If request came in over TCP, go straight to TCP upstream
if dnsProtocolFromContext(ctx) == protoTCP {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
return rm, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than what we can read over UDP.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, nil
}
setUpstreamProtocol(ctx, protoUDP)
return reply, nil
}
@@ -510,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
}
}
dnsConn := &dns.Conn{Conn: conn}
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)

View File

@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
Timeout: timeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream)
needsPrivate := u.lNet.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
return nil, 0, fmt.Errorf("create private client: %s", err)
}
}

View File

@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// UDP handler returns a truncated response to trigger TCP retry.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// TCP handler returns the full answer.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: udpHandler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "TCP response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
t.Cleanup(func() { _ = udpServer.Shutdown() })
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
"upstream should see capped EDNS0, not the client's 4096")
}
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
// When the client advertises a large EDNS0 (4096) and the upstream
// truncates, the TCP response should NOT be truncated since the full
// answer fits within the client's original buffer.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// Add enough records to exceed MTU but fit within 4096
for i := range 20 {
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
})
}
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
go func() { _ = tcpServer.ActivateAndServe() }()
t.Cleanup(func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
})
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
// Client with large buffer: should get all records without truncation
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
// Client with small buffer: should get truncated response
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r2.SetEdns0(512, false)
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
require.NoError(t, err)
require.NotNil(t, rm2)
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}

View File

@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, w, query, startTime)
}

View File

@@ -499,6 +499,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
@@ -510,6 +521,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return err
}
// Inject firewall into DNS server now that it's available.
// The DNS server is created before the firewall because the route manager
// depends on the DNS server, and the firewall depends on the wg interface.
e.dnsServer.SetFirewall(e.firewall)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
@@ -1826,6 +1842,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics

View File

@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,

View File

@@ -4,11 +4,14 @@ import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
const renewTimeout = 10 * time.Second
const (
renewTimeout = 10 * time.Second
)
// Response holds the response from exposing a service.
type Response struct {
@@ -18,11 +21,13 @@ type Response struct {
PortAutoAssigned bool
}
// Request holds the parameters for exposing a local service via the management server.
// It is part of the embed API surface and exposed via a type alias.
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Protocol ProtocolType
Pin string
Password string
UserGroups []string
@@ -59,6 +64,8 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
return fromClientExposeResponse(resp), nil
}
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
// It is part of the embed API surface and exposed via a type alias.
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

View File

@@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) {
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")

View File

@@ -0,0 +1,40 @@
package expose
import (
"fmt"
"strings"
)
// ProtocolType represents the protocol used for exposing a service.
type ProtocolType int
const (
// ProtocolHTTP exposes the service as HTTP.
ProtocolHTTP ProtocolType = 0
// ProtocolHTTPS exposes the service as HTTPS.
ProtocolHTTPS ProtocolType = 1
// ProtocolTCP exposes the service as TCP.
ProtocolTCP ProtocolType = 2
// ProtocolUDP exposes the service as UDP.
ProtocolUDP ProtocolType = 3
// ProtocolTLS exposes the service as TLS.
ProtocolTLS ProtocolType = 4
)
// ParseProtocolType parses a protocol string into a ProtocolType.
func ParseProtocolType(s string) (ProtocolType, error) {
switch strings.ToLower(s) {
case "http":
return ProtocolHTTP, nil
case "https":
return ProtocolHTTPS, nil
case "tcp":
return ProtocolTCP, nil
case "udp":
return ProtocolUDP, nil
case "tls":
return ProtocolTLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
}
}

View File

@@ -9,7 +9,7 @@ import (
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Protocol: ProtocolType(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
@@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,

View File

@@ -39,6 +39,18 @@ const (
DefaultAdminURL = "https://app.netbird.io:443"
)
// mgmProber is the subset of management client needed for URL migration probes.
type mgmProber interface {
HealthCheck() error
Close() error
}
// newMgmProber creates a management client for probing URL reachability.
// Overridden in tests to avoid real network calls.
var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) {
return mgm.NewClient(ctx, addr, key, tlsEnabled)
}
var DefaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
@@ -753,21 +765,19 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, err
}
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
defer func() {
err = client.Close()
if err != nil {
if err := client.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
if err = client.HealthCheck(); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -10,12 +10,21 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util"
)
type mockMgmProber struct{}
func (m *mockMgmProber) HealthCheck() error {
return nil
}
func (m *mockMgmProber) Close() error { return nil }
func TestGetConfig(t *testing.T) {
// case 1: new default config has to be generated
config, err := UpdateOrCreateConfig(ConfigInput{
@@ -234,6 +243,12 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
}
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
return &mockMgmProber{}, nil
}
t.Cleanup(func() { newMgmProber = origProber })
tests := []struct {
name string
previousManagementURL string
@@ -273,18 +288,17 @@ func TestUpdateOldManagementURL(t *testing.T) {
ConfigPath: configPath,
})
require.NoError(t, err, "failed to create testing config")
previousStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
previousContent, err := os.ReadFile(configPath)
require.NoError(t, err, "failed to read initial config")
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
require.NoError(t, err, "got error when updating old management url")
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
newStats, err := os.Stat(configPath)
require.NoError(t, err, "failed to create testing config stats")
switch tt.fileShouldNotChange {
case true:
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
case false:
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
newContent, err := os.ReadFile(configPath)
require.NoError(t, err, "failed to read updated config")
if tt.fileShouldNotChange {
require.Equal(t, string(previousContent), string(newContent), "file should not change")
} else {
require.NotEqual(t, string(previousContent), string(newContent), "file should have changed")
}
})
}

View File

@@ -52,6 +52,7 @@ type Manager interface {
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
return maps.Clone(m.clientRoutes)
}
// GetSelectedClientRoutes returns only the currently selected/active client routes,
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
// if traffic should be routed through the tunnel.
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
m.mux.Lock()
defer m.mux.Unlock()
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()

View File

@@ -18,6 +18,7 @@ type MockManager struct {
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil
}
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
return nil
}
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
if m.GetSelectedClientRoutesFunc != nil {
return m.GetSelectedClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -31,26 +31,11 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
// initialRoutes contains fake IP block for interface configuration
filteredInitial := make([]*route.Route, 0)
for _, r := range initialRoutes {
if r.IsDynamic() {
continue
}
filteredInitial = append(filteredInitial, r)
}
n.initialRoutes = filteredInitial
// routesForComparison excludes fake IP block for comparison with new routes
filteredComparison := make([]*route.Route, 0)
for _, r := range routesForComparison {
if r.IsDynamic() {
continue
}
filteredComparison = append(filteredComparison, r)
}
n.currentRoutes = filteredComparison
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
@@ -83,13 +68,43 @@ func (n *Notifier) notify() {
return
}
routeStrings := n.routesToStrings(n.currentRoutes)
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ","))
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {
if !r.IsDynamic() {
out = append(out, r)
}
}
return out
}
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
nets := make([]string, 0, len(routes))
for _, r := range routes {

View File

@@ -1359,6 +1359,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
if engine.IsBlockInbound() {
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")

View File

@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
sshServer := s.sshServer
if sshServer == nil {
s.mu.Unlock()
return nil
}
s.sshServer = nil
s.listener = nil
s.mu.Unlock()
if err := s.sshServer.Close(); err != nil {
// Close outside the lock: session handlers need s.mu for unregisterSession.
if err := sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.sshServer = nil
s.listener = nil
s.mu.Lock()
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
}
}
maps.Clear(s.remoteForwardListeners)
s.mu.Unlock()
return nil
}

View File

@@ -153,6 +153,9 @@ func networkAddresses() ([]NetworkAddress, error) {
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}

View File

@@ -324,6 +324,7 @@ type serviceClient struct {
exitNodeMu sync.Mutex
mExitNodeItems []menuHandler
exitNodeRetryCancel context.CancelFunc
mExitNodeSeparator *systray.MenuItem
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window

View File

@@ -421,6 +421,10 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
node.Remove()
}
s.mExitNodeItems = nil
if s.mExitNodeSeparator != nil {
s.mExitNodeSeparator.Remove()
s.mExitNodeSeparator = nil
}
if s.mExitNodeDeselectAll != nil {
s.mExitNodeDeselectAll.Remove()
s.mExitNodeDeselectAll = nil
@@ -453,31 +457,37 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
}
if showDeselectAll {
s.mExitNode.AddSeparator()
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
s.mExitNodeDeselectAll = deselectAllItem
go func() {
for {
_, ok := <-deselectAllItem.ClickedCh
if !ok {
// channel closed: exit the goroutine
return
}
exitNodes, err := s.handleExitNodeMenuDeselectAll()
if err != nil {
log.Warnf("failed to handle deselect all exit nodes: %v", err)
} else {
s.exitNodeMu.Lock()
s.recreateExitNodeMenu(exitNodes)
s.exitNodeMu.Unlock()
}
}
}()
s.addExitNodeDeselectAll()
}
}
func (s *serviceClient) addExitNodeDeselectAll() {
sep := s.mExitNode.AddSubMenuItem("───────────────", "")
sep.Disable()
s.mExitNodeSeparator = sep
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
s.mExitNodeDeselectAll = deselectAllItem
go func() {
for {
_, ok := <-deselectAllItem.ClickedCh
if !ok {
return
}
exitNodes, err := s.handleExitNodeMenuDeselectAll()
if err != nil {
log.Warnf("failed to handle deselect all exit nodes: %v", err)
} else {
s.exitNodeMu.Lock()
s.recreateExitNodeMenu(exitNodes)
s.exitNodeMu.Unlock()
}
}
}()
}
func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) {
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()

45
go.mod
View File

@@ -17,13 +17,13 @@ require (
github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.46.0
golang.org/x/sys v0.39.0
golang.org/x/crypto v0.48.0
golang.org/x/sys v0.41.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.77.0
google.golang.org/protobuf v1.36.10
google.golang.org/grpc v1.79.3
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -49,6 +49,7 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang/mock v1.6.0
@@ -101,21 +102,21 @@ require (
github.com/vmihailenco/msgpack/v5 v5.4.1
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0
go.opentelemetry.io/otel v1.42.0
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
go.opentelemetry.io/otel/metric v1.42.0
go.opentelemetry.io/otel/sdk/metric v1.42.0
go.uber.org/mock v0.5.2
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.30.0
golang.org/x/net v0.47.0
golang.org/x/mod v0.32.0
golang.org/x/net v0.51.0
golang.org/x/oauth2 v0.34.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.38.0
golang.org/x/term v0.40.0
golang.org/x/time v0.14.0
google.golang.org/api v0.257.0
gopkg.in/yaml.v3 v3.0.1
@@ -181,7 +182,6 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -249,12 +249,13 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/russellhaering/goxmldsig v1.5.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
github.com/shoenig/go-m1cpu v0.2.0 // indirect
github.com/shoenig/go-m1cpu v0.2.1 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
@@ -269,15 +270,15 @@ require (
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)

90
go.sum
View File

@@ -487,10 +487,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
@@ -511,8 +513,8 @@ github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKd
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/go-m1cpu v0.2.0 h1:t4GNqvPZ84Vjtpboo/kT3pIkbaK3vc+JIlD/Wz1zSFY=
github.com/shoenig/go-m1cpu v0.2.0/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w=
github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4=
github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk=
github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI=
@@ -603,26 +605,26 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs=
go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
@@ -633,8 +635,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -648,8 +650,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -666,8 +668,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -686,8 +688,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
@@ -738,8 +740,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -752,8 +754,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -765,8 +767,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -780,8 +782,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -799,12 +801,12 @@ google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4=
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
@@ -815,8 +817,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -170,20 +170,66 @@ type Connector struct {
}
// ToStorageConnector converts a Connector to storage.Connector type.
// It maps custom connector types (e.g., "zitadel", "entra") to Dex-native types
// and augments the config with OIDC defaults when needed.
func (c *Connector) ToStorageConnector() (storage.Connector, error) {
data, err := json.Marshal(c.Config)
dexType, augmentedConfig := mapConnectorToDex(c.Type, c.Config)
data, err := json.Marshal(augmentedConfig)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err)
}
return storage.Connector{
ID: c.ID,
Type: c.Type,
Type: dexType,
Name: c.Name,
Config: data,
}, nil
}
// mapConnectorToDex maps custom connector types to Dex-native types and applies
// OIDC defaults. This ensures static connectors from config files or env vars
// are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
return "oidc", applyOIDCDefaults(connType, config)
default:
return connType, config
}
}
// applyOIDCDefaults clones the config map, sets common OIDC defaults,
// and applies provider-specific overrides.
func applyOIDCDefaults(connType string, config map[string]interface{}) map[string]interface{} {
augmented := make(map[string]interface{}, len(config)+4)
for k, v := range config {
augmented[k] = v
}
setDefault(augmented, "scopes", []string{"openid", "profile", "email"})
setDefault(augmented, "insecureEnableGroups", true)
setDefault(augmented, "insecureSkipEmailVerified", true)
switch connType {
case "zitadel":
setDefault(augmented, "getUserInfo", true)
case "entra":
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return augmented
}
// setDefault sets a key in the map only if it doesn't already exist.
func setDefault(m map[string]interface{}, key string, value interface{}) {
if _, ok := m[key]; !ok {
m[key] = value
}
}
// StorageConfig is a configuration that can create a storage.
type StorageConfig interface {
Open(logger *slog.Logger) (storage.Storage, error)

View File

@@ -4,6 +4,7 @@ package dex
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -19,10 +20,13 @@ import (
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Config matches what management/internals/server/server.go expects
@@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string {
}
return issuer + "/auth"
}
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
keys, err := p.storage.GetKeys(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
}
if keys.SigningKeyPub == nil {
return nil, fmt.Errorf("no public keys found in storage")
}
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
// signing key first, then all verification (rotated) keys.
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
joseKeys = append(joseKeys, *keys.SigningKeyPub)
for _, vk := range keys.VerificationKeys {
if vk.PublicKey != nil {
joseKeys = append(joseKeys, *vk.PublicKey)
}
}
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
// then deserialize into our Jwks type, so the JSON field mapping is identical
// to what the /keys HTTP endpoint would return.
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
data, err := json.Marshal(joseSet)
if err != nil {
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
}
jwks := &nbjwt.Jwks{}
if err := json.Unmarshal(data, jwks); err != nil {
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
}
jwks.ExpiresInTime = keys.NextRotation
return jwks, nil
}

View File

@@ -2,11 +2,14 @@ package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
sqllib "github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -197,6 +200,295 @@ enablePasswordDB: true
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
}
// openTestStorage creates a SQLite storage in the given directory for testing.
func openTestStorage(t *testing.T, tmpDir string) storage.Storage {
t.Helper()
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
stor, err := (&sqllib.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return stor
}
func TestStaticConnectors_CreatedFromYAML(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: My OIDC Provider
config:
issuer: https://accounts.example.com
clientID: test-client-id
clientSecret: test-client-secret
redirectURI: http://localhost:5556/dex/callback
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
// Open storage and run initializeStorage directly (avoids Dex server
// trying to dial the OIDC issuer)
stor := openTestStorage(t, tmpDir)
defer stor.Close()
err = initializeStorage(ctx, stor, yamlConfig)
require.NoError(t, err)
// Verify connector was created in storage
conn, err := stor.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "my-oidc", conn.ID)
assert.Equal(t, "My OIDC Provider", conn.Name)
assert.Equal(t, "oidc", conn.Type)
// Verify config fields were serialized correctly
var configMap map[string]interface{}
err = json.Unmarshal(conn.Config, &configMap)
require.NoError(t, err)
assert.Equal(t, "https://accounts.example.com", configMap["issuer"])
assert.Equal(t, "test-client-id", configMap["clientID"])
}
func TestStaticConnectors_UpdatedOnRestart(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-update-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First: load config with initial connector
yamlContent1 := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + dbFile + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: Original Name
config:
issuer: https://accounts.example.com
clientID: original-client-id
clientSecret: original-secret
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent1), 0644)
require.NoError(t, err)
yamlConfig1, err := LoadConfig(configPath)
require.NoError(t, err)
stor := openTestStorage(t, tmpDir)
err = initializeStorage(ctx, stor, yamlConfig1)
require.NoError(t, err)
// Verify initial state
conn, err := stor.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "Original Name", conn.Name)
var configMap1 map[string]interface{}
err = json.Unmarshal(conn.Config, &configMap1)
require.NoError(t, err)
assert.Equal(t, "original-client-id", configMap1["clientID"])
// Close storage to simulate restart
stor.Close()
// Second: load updated config against the same DB
yamlContent2 := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + dbFile + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: Updated Name
config:
issuer: https://accounts.example.com
clientID: updated-client-id
clientSecret: updated-secret
`
err = os.WriteFile(configPath, []byte(yamlContent2), 0644)
require.NoError(t, err)
yamlConfig2, err := LoadConfig(configPath)
require.NoError(t, err)
stor2 := openTestStorage(t, tmpDir)
defer stor2.Close()
err = initializeStorage(ctx, stor2, yamlConfig2)
require.NoError(t, err)
// Verify connector was updated, not duplicated
allConnectors, err := stor2.ListConnectors(ctx)
require.NoError(t, err)
nonLocalCount := 0
for _, c := range allConnectors {
if c.ID != "local" {
nonLocalCount++
}
}
assert.Equal(t, 1, nonLocalCount, "connector should be updated, not duplicated")
conn2, err := stor2.GetConnector(ctx, "my-oidc")
require.NoError(t, err)
assert.Equal(t, "Updated Name", conn2.Name)
var configMap2 map[string]interface{}
err = json.Unmarshal(conn2.Config, &configMap2)
require.NoError(t, err)
assert.Equal(t, "updated-client-id", configMap2["clientID"])
}
func TestStaticConnectors_MultipleConnectors(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-multi-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
connectors:
- type: oidc
id: my-oidc
name: My OIDC Provider
config:
issuer: https://accounts.example.com
clientID: oidc-client-id
clientSecret: oidc-secret
- type: google
id: my-google
name: Google Login
config:
clientID: google-client-id
clientSecret: google-secret
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
stor := openTestStorage(t, tmpDir)
defer stor.Close()
err = initializeStorage(ctx, stor, yamlConfig)
require.NoError(t, err)
allConnectors, err := stor.ListConnectors(ctx)
require.NoError(t, err)
// Build a map for easier assertion
connByID := make(map[string]storage.Connector)
for _, c := range allConnectors {
connByID[c.ID] = c
}
// Verify both static connectors exist
oidcConn, ok := connByID["my-oidc"]
require.True(t, ok, "oidc connector should exist")
assert.Equal(t, "My OIDC Provider", oidcConn.Name)
assert.Equal(t, "oidc", oidcConn.Type)
var oidcConfig map[string]interface{}
err = json.Unmarshal(oidcConn.Config, &oidcConfig)
require.NoError(t, err)
assert.Equal(t, "oidc-client-id", oidcConfig["clientID"])
googleConn, ok := connByID["my-google"]
require.True(t, ok, "google connector should exist")
assert.Equal(t, "Google Login", googleConn.Name)
assert.Equal(t, "google", googleConn.Type)
var googleConfig map[string]interface{}
err = json.Unmarshal(googleConn.Config, &googleConfig)
require.NoError(t, err)
assert.Equal(t, "google-client-id", googleConfig["clientID"])
// Verify local connector still exists alongside them (enablePasswordDB: true)
localConn, ok := connByID["local"]
require.True(t, ok, "local connector should exist")
assert.Equal(t, "local", localConn.Type)
}
func TestStaticConnectors_EmptyList(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-static-conn-empty-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
provider, err := NewProviderFromYAML(ctx, yamlConfig)
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
// No static connectors configured, so ListConnectors should return empty
connectors, err := provider.ListConnectors(ctx)
require.NoError(t, err)
assert.Empty(t, connectors)
// But local connector should still exist
localConn, err := provider.Storage().GetConnector(ctx, "local")
require.NoError(t, err)
assert.Equal(t, "local", localConn.ID)
}
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
ctx := context.Background()

View File

@@ -172,8 +172,11 @@ init_environment() {
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo ""
echo "Login with the following credentials:"
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
install -m 600 /dev/null .env
printf 'Email: admin@%s\nPassword: %s\n' \
"$NETBIRD_DOMAIN" "$NETBIRD_ADMIN_PASSWORD" >> .env
echo "Email: admin@$NETBIRD_DOMAIN"
echo "Password: $NETBIRD_ADMIN_PASSWORD"
echo ""
echo "Dex admin UI is not available (Dex has no built-in UI)."
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"

View File

@@ -563,8 +563,11 @@ initEnvironment() {
echo -e "\nDone!\n"
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo "Login with the following credentials:"
echo "Username: $ZITADEL_ADMIN_USERNAME" | tee .env
echo "Password: $ZITADEL_ADMIN_PASSWORD" | tee -a .env
install -m 600 /dev/null .env
printf 'Username: %s\nPassword: %s\n' \
"$ZITADEL_ADMIN_USERNAME" "$ZITADEL_ADMIN_PASSWORD" >> .env
echo "Username: $ZITADEL_ADMIN_USERNAME"
echo "Password: $ZITADEL_ADMIN_PASSWORD"
}
renderCaddyfile() {

View File

@@ -1154,7 +1154,16 @@ print_builtin_traefik_instructions() {
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
echo ""
fi
echo ""
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
echo ""
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo "NetBird Proxy:"
echo " The proxy service is enabled and running."
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."

View File

@@ -154,9 +154,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
}
return nil
})

View File

@@ -31,19 +31,15 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type clusterCapabilities interface {
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
store store
validator domain.Validator
proxyManager proxyManager
clusterCapabilities clusterCapabilities
permissionsManager permissions.Manager
store store
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
}
@@ -57,11 +53,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
}
}
// SetClusterCapabilities sets the cluster capabilities provider for domain queries.
func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) {
m.clusterCapabilities = caps
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
@@ -97,10 +88,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
Type: domain.TypeFree,
Validated: true,
}
if m.clusterCapabilities != nil {
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
}
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
ret = append(ret, d)
}
@@ -114,8 +103,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
Type: domain.TypeCustom,
Validated: d.Validated,
}
if m.clusterCapabilities != nil && d.TargetCluster != "" {
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.

View File

@@ -11,11 +11,13 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -34,6 +36,4 @@ type Controller interface {
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}

View File

@@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster
return nil
}
// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports.
func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
}
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
// Returns nil when no proxy has reported the capability (defaults to false).
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)

View File

@@ -16,6 +16,8 @@ type store interface {
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
@@ -38,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
caps = *capabilities
}
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
@@ -48,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Capabilities: caps,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
@@ -118,6 +126,18 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error)
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr)
}
// ClusterRequireSubdomain returns whether any active proxy in the cluster
// requires a subdomain. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterRequireSubdomain(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {

View File

@@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
// ClusterSupportsCustomPorts mocks base method.
func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
}
// Disconnect mocks base method.
@@ -145,34 +173,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// ClusterSupportsCustomPorts mocks base method.
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()

View File

@@ -2,6 +2,17 @@ package proxy
import "time"
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
type Capabilities struct {
// SupportsCustomPorts indicates whether this proxy can bind arbitrary
// ports for TCP/UDP services. TLS uses SNI routing and is not gated.
SupportsCustomPorts *bool
// RequireSubdomain indicates whether a subdomain label is required in
// front of the cluster domain.
RequireSubdomain *bool
}
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
@@ -11,6 +22,7 @@ type Proxy struct {
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time
UpdatedAt time.Time
}

View File

@@ -75,16 +75,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
require.NoError(t, err)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
mockCaps := proxy.NewMockManager(ctrl)
mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes()
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -93,6 +95,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -75,22 +75,30 @@ type ClusterDeriver interface {
GetClusterDomains() []string
}
// CapabilityProvider queries proxy cluster capabilities from the database.
type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
@@ -237,7 +245,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
}
service.ProxyCluster = proxyCluster
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil {
return err
}
}
@@ -268,11 +276,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
// validateSubdomainRequirement checks whether the domain can be used bare
// (without a subdomain label) on the given cluster. If the cluster reports
// require_subdomain=true and the domain equals the cluster domain, it rejects.
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error {
if domain != cluster {
return nil
}
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster)
if requireSub != nil && *requireSub {
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
}
@@ -280,6 +288,8 @@ func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -287,7 +297,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -307,12 +317,23 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
})
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
// clusterCustomPorts queries whether the cluster supports custom ports.
// Must be called before entering a transaction: the underlying query uses
// the main DB handle, which deadlocks when called inside a transaction
// that already holds the connection.
func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster)
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
@@ -396,12 +417,14 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -504,21 +527,58 @@ type serviceUpdateInfo struct {
}
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service)
if err != nil {
return nil, err
}
svcForCaps := *service
svcForCaps.ProxyCluster = effectiveCluster
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
})
return &updateInfo, err
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
// resolveEffectiveCluster determines the cluster that will be used after the update.
// It reads the existing service without locking and derives the new cluster if the domain changed.
func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) {
existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID)
if err != nil {
return "", err
}
if existing.Domain == svc.Domain {
return existing.ProxyCluster, nil
}
if m.clusterDeriver != nil {
derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
return derived, nil
}
}
return existing.ProxyCluster, nil
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if existingService.Terminated {
return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated")
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
@@ -534,7 +594,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
return err
}
@@ -546,7 +606,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
@@ -1059,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}

View File

@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -1324,11 +1324,11 @@ func TestValidateSubdomainRequirement(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mockCaps := proxy.NewMockManager(ctrl)
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mgr := &Manager{proxyController: mockCtrl}
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
mgr := &Manager{capabilities: mockCaps}
err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "requires a subdomain label")

View File

@@ -184,6 +184,7 @@ type Service struct {
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
Terminated bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
@@ -256,7 +257,7 @@ func (s *Service) ToAPIResponse() *api.Service {
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
Enabled: target.Enabled && !s.Terminated,
}
opts := targetOptionsToAPI(target.Options)
if opts == nil {
@@ -286,7 +287,8 @@ func (s *Service) ToAPIResponse() *api.Service {
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
Enabled: s.Enabled && !s.Terminated,
Terminated: &s.Terminated,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
@@ -785,6 +787,11 @@ func (s *Service) validateHTTPTargets() error {
}
func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto.
target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
@@ -1125,6 +1132,7 @@ func (s *Service) Copy() *Service {
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
Terminated: s.Terminated,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: authCopy,

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
issuer := s.Config.HttpConfig.AuthIssuer
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
var keyFetcher nbjwt.KeyFetcher
// Use embedded IdP configuration if available
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
// Use localhost keys location for internal validation (management has embedded Dex)
keysLocation = oauthProvider.GetLocalKeysLocation()
keyFetcher = oauthProvider.GetKeyFetcher()
// Fall back to default keys location if direct key fetching is not available
if keyFetcher == nil {
keysLocation = oauthProvider.GetLocalKeysLocation()
}
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
keysLocation,
userIDClaim,
audiences,
signingKeyRefreshEnabled)
signingKeyRefreshEnabled,
keyFetcher)
})
}

View File

@@ -117,9 +117,11 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
if embeddedEnabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP service: %v", err)
@@ -195,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
})
}
@@ -212,9 +214,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
s.AfterInit(func(s *BaseServer) {
m.SetClusterCapabilities(s.ServiceProxyController())
})
return &m
})
}

View File

@@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
}
}
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.Delete(proxyID)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
log.WithFields(log.Fields{
@@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
if !proxyAcceptsMapping(conn, m) {
continue
}
mappings = append(mappings, m)
}
return mappings, nil
@@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
log.Debugf("Sending service update to cluster %s", clusterAddr)
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
connVal, ok := s.connectedProxies.Load(proxyID)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
}
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4
// mappings with a custom listen port, since they don't understand the
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
// Old proxies that never reported capabilities don't understand
// custom port mappings.
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
@@ -508,64 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
}
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
// cluster reports custom port support. Returns nil if no proxy has reported
// capabilities (old proxies that predate the field).
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
continue
}
if *conn.capabilities.SupportsCustomPorts {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
// ClusterRequireSubdomain returns whether any connected proxy in the given
// cluster reports that a subdomain is required. Returns nil if no proxy has
// reported the capability (defaults to not required).
func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
if s.proxyController == nil {
return nil
}
var hasCapabilities bool
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
connVal, ok := s.connectedProxies.Load(pid)
if !ok {
continue
}
conn := connVal.(*proxyConnection)
if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil {
continue
}
if *conn.capabilities.RequireSubdomain {
return ptr(true)
}
hasCapabilities = true
}
if hasCapabilities {
return ptr(false)
}
return nil
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {

View File

@@ -53,14 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
return nil
}
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
return ptr(true)
}
func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool {
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
@@ -355,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
const cluster = "proxy.example.com"
// Proxy A supports custom ports.
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
// Proxy B does NOT support custom ports (shared cloud proxy).
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// Modern proxy reports capabilities.
chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
// Legacy proxy never reported capabilities (nil).
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
ctx := context.Background()
// TLS passthrough works on all proxies regardless of custom port support.
// TLS passthrough with custom port: all proxies receive it (SNI routing).
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tls",
@@ -375,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
msgA := drainMapping(chA)
msgB := drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping")
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)")
// Send an HTTP mapping: both should receive it.
// TCP mapping with custom port: only modern proxy receives it.
tcpMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-tcp",
AccountId: "account-1",
Domain: "db.example.com",
Mode: "tcp",
ListenPort: 5432,
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
}
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping")
assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping")
// HTTP mapping (no listen port): both receive it.
httpMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-http",
@@ -391,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
msgA = drainMapping(chA)
msgB = drainMapping(chB)
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping")
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping")
// Proxy that reports SupportsCustomPorts=false still receives custom-port
// mappings because it understands the protocol (it's new enough).
chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping")
}
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
@@ -408,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
const cluster = "proxy.example.com"
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
// Legacy proxy (no capabilities) still receives TLS since it uses SNI.
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
tlsMapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
@@ -421,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
msg := drainMapping(chShared)
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
msg := drainMapping(chLegacy)
assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)")
}
// TestServiceModifyNotifications exercises every possible modification
@@ -589,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) {
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
chLegacy := registerFakeProxy(s, "legacy", cluster)
// TLS passthrough works on all proxies regardless of custom port support
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
@@ -608,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) {
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
chLegacy := registerFakeProxy(s, "legacy", cluster)
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
mapping.ListenPort = 0 // default port

View File

@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
}
// GetIdentityProvider mocks base method.

View File

@@ -3138,7 +3138,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
return manager, updateManager, nil
}

View File

@@ -0,0 +1,61 @@
package store
// This file contains migration-only methods on Store.
// They satisfy the migration.MigrationEventStore interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp/migration"
)
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := store.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&activity.Event{}).
Where("initiator_id = ?", oldUserID).
Update("initiator_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.initiator_id: %w", err)
}
if err := tx.Model(&activity.Event{}).
Where("target_id = ?", oldUserID).
Update("target_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.target_id: %w", err)
}
// Raw exec: GORM can't update a PK via Model().Update()
if err := tx.Exec(
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
).Error; err != nil {
return fmt.Errorf("update deleted_users.id: %w", err)
}
return nil
})
}

View File

@@ -0,0 +1,161 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUpdateUserID(t *testing.T) {
ctx := context.Background()
newStore := func(t *testing.T) *Store {
t.Helper()
key, _ := crypt.GenerateKey()
s, err := NewSqlStore(ctx, t.TempDir(), key)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { s.Close(ctx) }) //nolint
return s
}
t.Run("updates initiator_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "old-user",
TargetID: "some-peer",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].InitiatorID)
})
t.Run("updates target_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "some-admin",
TargetID: "old-user",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].TargetID)
})
t.Run("updates deleted_users id", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
// Save an event with email/name meta to create a deleted_users row for "old-user"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "old-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
// Save another event referencing new-user with email/name meta.
// This should upsert (not conflict) because the PK was already migrated.
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "new-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
// The deleted user info should be retrievable via Get (joined on target_id)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
assert.Equal(t, "new-user", ev.TargetID)
}
})
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
store := newStore(t)
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
assert.NoError(t, err)
})
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-a",
TargetID: "peer-1",
AccountID: accountID,
})
assert.NoError(t, err)
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-b",
TargetID: "peer-2",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
if ev.TargetID == "peer-1" {
assert.Equal(t, "user-a-new", ev.InitiatorID)
} else {
assert.Equal(t, "user-b", ev.InitiatorID)
}
}
})
}

View File

@@ -33,15 +33,20 @@ type manager struct {
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
var jwtValidator *nbjwt.Validator
if keyFetcher != nil {
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
} else {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator = nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
}
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),

View File

@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)

View File

@@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()
if gl.db == nil {
return nil, fmt.Errorf("geolocation database is not available")
}
var record Record
err := gl.db.Lookup(ip, &record)
if err != nil {
@@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
gl.mux.Lock()
db := gl.db
gl.db = nil
gl.mux.Unlock()
if db != nil {
if err := db.Close(); err != nil {
return err
}
}

View File

@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}

View File

@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}

View File

@@ -0,0 +1,238 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Accounts_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all accounts", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Account{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
account := got[0]
assert.Equal(t, "test.com", account.Domain)
assert.Equal(t, "private", account.DomainCategory)
assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, 86400, account.Settings.PeerLoginExpiration)
assert.Equal(t, false, account.Settings.RegularUsersViewBlocked)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Accounts_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
trueVal := true
falseVal := false
tt := []struct {
name string
expectedStatus int
requestBody *api.AccountRequest
verifyResponse func(t *testing.T, account *api.Account)
verifyDB func(t *testing.T, account *types.Account)
}{
{
name: "Disable peer login expiration",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: 86400,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled)
},
},
{
name: "Update peer login expiration to 48h",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 172800,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, 172800, account.Settings.PeerLoginExpiration)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration)
},
},
{
name: "Enable regular users view blocked",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
RegularUsersViewBlocked: true,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.Equal(t, true, account.Settings.RegularUsersViewBlocked)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked)
},
},
{
name: "Enable groups propagation",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
GroupsPropagationEnabled: &trueVal,
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.NotNil(t, account.Settings.GroupsPropagationEnabled)
assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled)
},
},
{
name: "Enable JWT groups",
requestBody: &api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 86400,
GroupsPropagationEnabled: &falseVal,
JwtGroupsEnabled: &trueVal,
JwtGroupsClaimName: stringPointer("groups"),
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, account *api.Account) {
t.Helper()
assert.NotNil(t, account.Settings.JwtGroupsEnabled)
assert.Equal(t, true, *account.Settings.JwtGroupsEnabled)
assert.NotNil(t, account.Settings.JwtGroupsClaimName)
assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName)
},
verifyDB: func(t *testing.T, dbAccount *types.Account) {
t.Helper()
assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled)
assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName)
},
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
got := &api.Account{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, testing_tools.TestAccountId, got.Id)
assert.Equal(t, "test.com", got.Domain)
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbAccount := testing_tools.VerifyAccountSettings(t, db)
tc.verifyDB(t, dbAccount)
})
}
}
}
func stringPointer(s string) *string {
return &s
}

View File

@@ -0,0 +1,554 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Nameservers_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all nameservers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/nameservers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.NameserverGroup{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testNSGroup", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Nameservers_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
expectedStatus int
expectGroup bool
}{
{
name: "Get existing nameserver group",
nsGroupId: "testNSGroupId",
expectedStatus: http.StatusOK,
expectGroup: true,
},
{
name: "Get non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
expectedStatus: http.StatusNotFound,
expectGroup: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectGroup {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testNSGroupId", got.Id)
assert.Equal(t, "testNSGroup", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Nameservers_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PostApiDnsNameserversJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
}{
{
name: "Create nameserver group with single NS",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "newNSGroup",
Description: "a new nameserver group",
Nameservers: []api.Nameserver{
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: false,
Domains: []string{"test.com"},
Enabled: true,
SearchDomainsEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.NotEmpty(t, nsGroup.Id)
assert.Equal(t, "newNSGroup", nsGroup.Name)
assert.Equal(t, 1, len(nsGroup.Nameservers))
assert.Equal(t, false, nsGroup.Primary)
},
},
{
name: "Create primary nameserver group",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "primaryNS",
Description: "primary nameserver",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.Equal(t, true, nsGroup.Primary)
},
},
{
name: "Create nameserver group with empty groups",
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
Name: "emptyGroupsNS",
Description: "no groups",
Nameservers: []api.Nameserver{
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
},
Groups: []string{},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/nameservers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created NS group directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbNS := testing_tools.VerifyNSGroupInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbNS.Name)
assert.Equal(t, got.Primary, dbNS.Primary)
assert.Equal(t, len(got.Nameservers), len(dbNS.NameServers))
assert.Equal(t, got.Enabled, dbNS.Enabled)
assert.Equal(t, got.SearchDomainsEnabled, dbNS.SearchDomainsEnabled)
}
})
}
}
}
func Test_Nameservers_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
requestBody *api.PutApiDnsNameserversNsgroupIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
}{
{
name: "Update nameserver group name",
nsGroupId: "testNSGroupId",
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "updatedNSGroup",
Description: "updated description",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: false,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: false,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
t.Helper()
assert.Equal(t, "updatedNSGroup", nsGroup.Name)
assert.Equal(t, "updated description", nsGroup.Description)
},
},
{
name: "Update non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "whatever",
Nameservers: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
Groups: []string{testing_tools.TestGroupId},
Primary: true,
Domains: []string{},
Enabled: true,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.NameserverGroup{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated NS group directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbNS := testing_tools.VerifyNSGroupInDB(t, db, tc.nsGroupId)
assert.Equal(t, "updatedNSGroup", dbNS.Name)
assert.Equal(t, "updated description", dbNS.Description)
assert.Equal(t, false, dbNS.Primary)
assert.Equal(t, true, dbNS.Enabled)
assert.Equal(t, 1, len(dbNS.NameServers))
assert.Equal(t, false, dbNS.SearchDomainsEnabled)
}
})
}
}
}
func Test_Nameservers_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
nsGroupId string
expectedStatus int
}{
{
name: "Delete existing nameserver group",
nsGroupId: "testNSGroupId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing nameserver group",
nsGroupId: "nonExistingNSGroupId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyNSGroupNotInDB(t, db, tc.nsGroupId)
}
})
}
}
}
func Test_DnsSettings_Get(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get DNS settings", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/settings", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.DNSSettings{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotNil(t, got.DisabledManagementGroups)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_DnsSettings_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PutApiDnsSettingsJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, settings *api.DNSSettings)
expectedDBDisabledMgmtLen int
expectedDBDisabledMgmtItem string
}{
{
name: "Update disabled management groups",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
t.Helper()
assert.Equal(t, 1, len(settings.DisabledManagementGroups))
assert.Equal(t, testing_tools.TestGroupId, settings.DisabledManagementGroups[0])
},
expectedDBDisabledMgmtLen: 1,
expectedDBDisabledMgmtItem: testing_tools.TestGroupId,
},
{
name: "Update with empty disabled management groups",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
t.Helper()
assert.Equal(t, 0, len(settings.DisabledManagementGroups))
},
expectedDBDisabledMgmtLen: 0,
},
{
name: "Update with non-existing group",
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"nonExistingGroupId"},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/dns/settings", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSSettings{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify DNS settings directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbAccount := testing_tools.VerifyAccountSettings(t, db)
assert.Equal(t, tc.expectedDBDisabledMgmtLen, len(dbAccount.DNSSettings.DisabledManagementGroups))
if tc.expectedDBDisabledMgmtItem != "" {
assert.Contains(t, dbAccount.DNSSettings.DisabledManagementGroups, tc.expectedDBDisabledMgmtItem)
}
}
})
}
}
}

Some files were not shown because too many files have changed in this diff Show More