mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Merge remote-tracking branch 'origin/main' into iptables-mangle-dnat-guard
This commit is contained in:
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -154,6 +154,26 @@ builds:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-idp-migrate
|
||||
dir: tools/idp-migrate
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-idp-migrate
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -166,6 +186,10 @@ archives:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
- id: netbird-idp-migrate
|
||||
builds:
|
||||
- netbird-idp-migrate
|
||||
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
## Contributor License Agreement
|
||||
|
||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
|
||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||
|
||||
@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
pi := PeerInfo{
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
int(p.ConnStatus),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
|
||||
@@ -2,11 +2,20 @@
|
||||
|
||||
package android
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
// Connection status constants exported via gomobile.
|
||||
const (
|
||||
ConnStatusIdle = int(peer.StatusIdle)
|
||||
ConnStatusConnecting = int(peer.StatusConnecting)
|
||||
ConnStatusConnected = int(peer.StatusConnected)
|
||||
)
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
ConnStatus int
|
||||
Routes PeerRoutes
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -200,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
@@ -211,26 +213,31 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http":
|
||||
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
switch p {
|
||||
case expose.ProtocolHTTP:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case "https":
|
||||
case expose.ProtocolHTTPS:
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
case "tcp":
|
||||
case expose.ProtocolTCP:
|
||||
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||
case "udp":
|
||||
case expose.ProtocolUDP:
|
||||
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||
case "tls":
|
||||
case expose.ProtocolTLS:
|
||||
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
return 0, fmt.Errorf("unhandled protocol type: %d", p)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
|
||||
@@ -41,7 +41,7 @@ func init() {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
|
||||
|
||||
@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
|
||||
return fmt.Errorf("install service: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been installed")
|
||||
return nil
|
||||
},
|
||||
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return err
|
||||
}
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
wasRunning, err := isServiceRunning()
|
||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||
return fmt.Errorf("check service status: %w", err)
|
||||
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
|
||||
return fmt.Errorf("install service with new config: %w", err)
|
||||
}
|
||||
|
||||
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||
}
|
||||
|
||||
if wasRunning {
|
||||
cmd.Println("Starting NetBird service...")
|
||||
if err := s.Start(); err != nil {
|
||||
|
||||
201
client/cmd/service_params.go
Normal file
201
client/cmd/service_params.go
Normal file
@@ -0,0 +1,201 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const serviceParamsFile = "service.json"
|
||||
|
||||
// serviceParams holds install-time service parameters that persist across
|
||||
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||
type serviceParams struct {
|
||||
LogLevel string `json:"log_level"`
|
||||
DaemonAddr string `json:"daemon_addr"`
|
||||
ManagementURL string `json:"management_url,omitempty"`
|
||||
ConfigPath string `json:"config_path,omitempty"`
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
// serviceParamsPath returns the path to the service params file.
|
||||
func serviceParamsPath() string {
|
||||
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
}
|
||||
|
||||
// loadServiceParams reads saved service parameters from disk.
|
||||
// Returns nil with no error if the file does not exist.
|
||||
func loadServiceParams() (*serviceParams, error) {
|
||||
path := serviceParamsPath()
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
var params serviceParams
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||
}
|
||||
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
// saveServiceParams writes current service parameters to disk atomically
|
||||
// with restricted permissions.
|
||||
func saveServiceParams(params *serviceParams) error {
|
||||
path := serviceParamsPath()
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||
return fmt.Errorf("save service params: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentServiceParams captures the current state of all package-level
|
||||
// variables into a serviceParams struct.
|
||||
func currentServiceParams() *serviceParams {
|
||||
params := &serviceParams{
|
||||
LogLevel: logLevel,
|
||||
DaemonAddr: daemonAddr,
|
||||
ManagementURL: managementURL,
|
||||
ConfigPath: configPath,
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err == nil && len(parsed) > 0 {
|
||||
params.ServiceEnvVars = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||
// to any flags that were not explicitly set.
|
||||
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
applyServiceParams(cmd, params)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyServiceParams merges saved parameters into package-level variables
|
||||
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||
// Flags that were Changed() are left untouched.
|
||||
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||
// != "" guard so that an older service.json missing the field doesn't
|
||||
// clobber the default with an empty string.
|
||||
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||
logLevel = params.LogLevel
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||
daemonAddr = params.DaemonAddr
|
||||
}
|
||||
|
||||
// For optional fields where empty means "use default", always apply so
|
||||
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||
managementURL = params.ManagementURL
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("config") {
|
||||
configPath = params.ConfigPath
|
||||
}
|
||||
|
||||
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||
logFiles = params.LogFiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||
profilesDisabled = params.DisableProfiles
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
// applyServiceEnvParams merges saved service environment variables.
|
||||
// If --service-env was explicitly set, explicit values win on key conflict
|
||||
// but saved keys not in the explicit set are carried over.
|
||||
// If --service-env was not set, saved env vars are used entirely.
|
||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if len(params.ServiceEnvVars) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !cmd.Flags().Changed("service-env") {
|
||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||
return
|
||||
}
|
||||
|
||||
// Explicit env vars were provided: merge saved values underneath.
|
||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||
maps.Copy(merged, params.ServiceEnvVars)
|
||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||
serviceEnvVars = envMapToSlice(merged)
|
||||
}
|
||||
|
||||
var resetParamsCmd = &cobra.Command{
|
||||
Use: "reset-params",
|
||||
Short: "Remove saved service install parameters",
|
||||
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
path := serviceParamsPath()
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
cmd.Println("No saved service parameters found")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("remove service params: %w", err)
|
||||
}
|
||||
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||
func envMapToSlice(m map[string]string) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
s = append(s, k+"="+v)
|
||||
}
|
||||
return s
|
||||
}
|
||||
523
client/cmd/service_params_test.go
Normal file
523
client/cmd/service_params_test.go
Normal file
@@ -0,0 +1,523 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
)
|
||||
|
||||
func TestServiceParamsPath(t *testing.T) {
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
|
||||
configs.StateDir = "/var/lib/netbird"
|
||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||
|
||||
configs.StateDir = "/custom/state"
|
||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||
}
|
||||
|
||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
ManagementURL: "https://my.server.com",
|
||||
ConfigPath: "/etc/netbird/config.json",
|
||||
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: false,
|
||||
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||
}
|
||||
|
||||
err := saveServiceParams(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and is valid JSON.
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, json.Valid(data))
|
||||
|
||||
loaded, err := loadServiceParams()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
original := configs.StateDir
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
configs.StateDir = tmpDir
|
||||
|
||||
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
params, err := loadServiceParams()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, params)
|
||||
}
|
||||
|
||||
func TestCurrentServiceParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
logLevel = "trace"
|
||||
daemonAddr = "tcp://127.0.0.1:9999"
|
||||
managementURL = "https://mgmt.example.com"
|
||||
configPath = "/tmp/test-config.json"
|
||||
logFiles = []string{"/tmp/test.log"}
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||
|
||||
params := currentServiceParams()
|
||||
|
||||
assert.Equal(t, "trace", params.LogLevel)
|
||||
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||
assert.True(t, params.DisableProfiles)
|
||||
assert.True(t, params.DisableUpdateSettings)
|
||||
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
origDaemonAddr := daemonAddr
|
||||
origManagementURL := managementURL
|
||||
origConfigPath := configPath
|
||||
origLogFiles := logFiles
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() {
|
||||
logLevel = origLogLevel
|
||||
daemonAddr = origDaemonAddr
|
||||
managementURL = origManagementURL
|
||||
configPath = origConfigPath
|
||||
logFiles = origLogFiles
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
serviceEnvVars = origServiceEnvVars
|
||||
})
|
||||
|
||||
// Reset all flags to defaults.
|
||||
logLevel = "info"
|
||||
daemonAddr = "unix:///var/run/netbird.sock"
|
||||
managementURL = ""
|
||||
configPath = "/etc/netbird/config.json"
|
||||
logFiles = []string{"/var/log/netbird/client.log"}
|
||||
profilesDisabled = false
|
||||
updateSettingsDisabled = false
|
||||
serviceEnvVars = nil
|
||||
|
||||
// Reset Changed state on all relevant flags.
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Simulate user explicitly setting --log-level via CLI.
|
||||
logLevel = "warn"
|
||||
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||
|
||||
saved := &serviceParams{
|
||||
LogLevel: "debug",
|
||||
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||
ManagementURL: "https://saved.example.com",
|
||||
ConfigPath: "/saved/config.json",
|
||||
LogFiles: []string{"/saved/client.log"},
|
||||
DisableProfiles: true,
|
||||
DisableUpdateSettings: true,
|
||||
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||
assert.Equal(t, "warn", logLevel)
|
||||
|
||||
// All other fields were not Changed, so they should use saved values.
|
||||
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||
assert.Equal(t, "/saved/config.json", configPath)
|
||||
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||
assert.True(t, profilesDisabled)
|
||||
assert.True(t, updateSettingsDisabled)
|
||||
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||
origProfilesDisabled := profilesDisabled
|
||||
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||
t.Cleanup(func() {
|
||||
profilesDisabled = origProfilesDisabled
|
||||
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||
})
|
||||
|
||||
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||
profilesDisabled = true
|
||||
updateSettingsDisabled = true
|
||||
|
||||
// Reset Changed state so flags appear unset.
|
||||
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
// Saved params have both as false.
|
||||
saved := &serviceParams{
|
||||
DisableProfiles: false,
|
||||
DisableUpdateSettings: false,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||
origManagementURL := managementURL
|
||||
t.Cleanup(func() { managementURL = origManagementURL })
|
||||
|
||||
managementURL = "https://leftover.example.com"
|
||||
|
||||
// Simulate saved params where management URL was explicitly cleared.
|
||||
saved := &serviceParams{
|
||||
LogLevel: "info",
|
||||
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||
f.Changed = false
|
||||
})
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
applyServiceParams(cmd, saved)
|
||||
|
||||
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||
}
|
||||
|
||||
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||
origLogLevel := logLevel
|
||||
t.Cleanup(func() { logLevel = origLogLevel })
|
||||
|
||||
logLevel = "info"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
// Should be a no-op.
|
||||
applyServiceParams(cmd, nil)
|
||||
assert.Equal(t, "info", logLevel)
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Set up a command with --service-env marked as Changed.
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||
|
||||
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{
|
||||
"SAVED": "val",
|
||||
"OVERLAP": "saved",
|
||||
},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
// Parse result for easier assertion.
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||
assert.Equal(t, "val", result["SAVED"])
|
||||
// Explicit wins on conflict.
|
||||
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
serviceEnvVars = nil
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||
}
|
||||
|
||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||
// added to serviceParams but not wired into these functions, this test fails.
|
||||
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect all JSON field names from the serviceParams struct.
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||
|
||||
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||
for k, v := range applyEnvFields {
|
||||
applyFields[k] = v
|
||||
}
|
||||
|
||||
for _, field := range structFields {
|
||||
assert.Contains(t, currentFields, field,
|
||||
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||
assert.Contains(t, applyFields, field,
|
||||
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||
require.NotEmpty(t, structFields)
|
||||
|
||||
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||
fieldsNotInArgs := map[string]bool{
|
||||
"ServiceEnvVars": true,
|
||||
}
|
||||
|
||||
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||
|
||||
// Forward: every struct field must appear in buildServiceArguments.
|
||||
for _, field := range structFields {
|
||||
if fieldsNotInArgs[field] {
|
||||
continue
|
||||
}
|
||||
globalVar := fieldToGlobalVar(field)
|
||||
assert.Contains(t, buildFields, globalVar,
|
||||
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||
}
|
||||
|
||||
// Reverse: every service-related global used in buildServiceArguments must
|
||||
// have a corresponding serviceParams field. This catches a developer adding
|
||||
// a new flag to buildServiceArguments without adding it to the struct.
|
||||
globalToField := make(map[string]string, len(structFields))
|
||||
for _, field := range structFields {
|
||||
globalToField[fieldToGlobalVar(field)] = field
|
||||
}
|
||||
// Identifiers in buildServiceArguments that are not service params
|
||||
// (builtins, boilerplate, loop variables).
|
||||
nonParamGlobals := map[string]bool{
|
||||
"args": true, "append": true, "string": true, "_": true,
|
||||
"logFile": true, // range variable over logFiles
|
||||
}
|
||||
for ref := range buildFields {
|
||||
if nonParamGlobals[ref] {
|
||||
continue
|
||||
}
|
||||
_, inStruct := globalToField[ref]
|
||||
assert.True(t, inStruct,
|
||||
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||
}
|
||||
}
|
||||
|
||||
// extractStructJSONFields returns field names from a named struct type.
|
||||
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||
t.Helper()
|
||||
var fields []string
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
ts, ok := n.(*ast.TypeSpec)
|
||||
if !ok || ts.Name.Name != structName {
|
||||
return true
|
||||
}
|
||||
st, ok := ts.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, f := range st.Fields.List {
|
||||
if len(f.Names) > 0 {
|
||||
fields = append(fields, f.Names[0].Name)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return fields
|
||||
}
|
||||
|
||||
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||
// named function, either as selector expressions (params.FieldName) or as
|
||||
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||
t.Helper()
|
||||
fieldSet := make(map[string]bool, len(fields))
|
||||
for _, f := range fields {
|
||||
fieldSet[f] = true
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
switch v := n.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
if fieldSet[v.Sel.Name] {
|
||||
found[v.Sel.Name] = true
|
||||
}
|
||||
case *ast.KeyValueExpr:
|
||||
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||
found[ident.Name] = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||
t.Helper()
|
||||
fn := findFuncDecl(file, funcName)
|
||||
require.NotNil(t, fn, "function %s not found", funcName)
|
||||
|
||||
refs := make(map[string]bool)
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
if ident, ok := n.(*ast.Ident); ok {
|
||||
refs[ident.Name] = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
return refs
|
||||
}
|
||||
|
||||
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if ok && fn.Name.Name == name {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||
// names used in buildServiceArguments and applyServiceParams.
|
||||
func fieldToGlobalVar(field string) string {
|
||||
m := map[string]string{
|
||||
"LogLevel": "logLevel",
|
||||
"DaemonAddr": "daemonAddr",
|
||||
"ManagementURL": "managementURL",
|
||||
"ConfigPath": "configPath",
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
if v, ok := m[field]; ok {
|
||||
return v
|
||||
}
|
||||
// Default: lowercase first letter.
|
||||
return strings.ToLower(field[:1]) + field[1:]
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice(t *testing.T) {
|
||||
m := map[string]string{"A": "1", "B": "2"}
|
||||
s := envMapToSlice(m)
|
||||
assert.Len(t, s, 2)
|
||||
assert.Contains(t, s, "A=1")
|
||||
assert.Contains(t, s, "B=2")
|
||||
}
|
||||
|
||||
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||
s := envMapToSlice(map[string]string{})
|
||||
assert.Empty(t, s)
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -33,14 +33,14 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -375,6 +375,32 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
|
||||
// It returns an ExposeSession. Call Wait on the session to keep it alive.
|
||||
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return nil, fmt.Errorf("expose manager not available")
|
||||
}
|
||||
|
||||
resp, err := mgr.Expose(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expose: %w", err)
|
||||
}
|
||||
|
||||
return &ExposeSession{
|
||||
Domain: resp.Domain,
|
||||
ServiceName: resp.ServiceName,
|
||||
ServiceURL: resp.ServiceURL,
|
||||
mgr: mgr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
|
||||
45
client/embed/expose.go
Normal file
45
client/embed/expose.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
)
|
||||
|
||||
const (
|
||||
// ExposeProtocolHTTP exposes the service as HTTP.
|
||||
ExposeProtocolHTTP = expose.ProtocolHTTP
|
||||
// ExposeProtocolHTTPS exposes the service as HTTPS.
|
||||
ExposeProtocolHTTPS = expose.ProtocolHTTPS
|
||||
// ExposeProtocolTCP exposes the service as TCP.
|
||||
ExposeProtocolTCP = expose.ProtocolTCP
|
||||
// ExposeProtocolUDP exposes the service as UDP.
|
||||
ExposeProtocolUDP = expose.ProtocolUDP
|
||||
// ExposeProtocolTLS exposes the service as TLS.
|
||||
ExposeProtocolTLS = expose.ProtocolTLS
|
||||
)
|
||||
|
||||
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
|
||||
type ExposeRequest = expose.Request
|
||||
|
||||
// ExposeProtocolType represents the protocol used for exposing a service.
|
||||
type ExposeProtocolType = expose.ProtocolType
|
||||
|
||||
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
|
||||
type ExposeSession struct {
|
||||
Domain string
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
|
||||
mgr *expose.Manager
|
||||
}
|
||||
|
||||
// Wait blocks while keeping the expose session alive.
|
||||
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
|
||||
func (s *ExposeSession) Wait(ctx context.Context) error {
|
||||
if s == nil || s.mgr == nil {
|
||||
return errors.New("expose session is not initialized")
|
||||
}
|
||||
return s.mgr.KeepAlive(ctx, s.Domain)
|
||||
}
|
||||
@@ -286,6 +286,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -169,6 +169,14 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
@@ -346,6 +346,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -140,6 +140,17 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
manager.SetUDPPacketHook(
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,9 +18,7 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
drop bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -44,6 +44,10 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
@@ -610,12 +617,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -634,12 +635,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
73
client/internal/connect_android_default.go
Normal file
73
client/internal/connect_android_default.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||
// discovered — connections will fall back to relay. Applications that need P2P
|
||||
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||
// Android's ConnectivityManager to enumerate network interfaces.
|
||||
type noopIFaceDiscover struct{}
|
||||
|
||||
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||
// Relay connections still work; only P2P hole-punching is disabled.
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||
// Network change events are ignored since the embed client manages its own
|
||||
// reconnection logic via the engine's built-in retry mechanism.
|
||||
type noopNetworkChangeListener struct{}
|
||||
|
||||
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||
// No-op: embed.Client relies on the engine's internal reconnection
|
||||
// logic rather than OS-level network change notifications.
|
||||
}
|
||||
|
||||
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||
// network stack, not by OS-level interface configuration.
|
||||
}
|
||||
|
||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||
// DNS readiness notifications are not needed in netstack/embed mode
|
||||
// since system DNS is disabled and DNS resolution happens externally.
|
||||
type noopDnsReadyListener struct{}
|
||||
|
||||
func (noopDnsReadyListener) OnReady() {
|
||||
// No-op: embed.Client does not need DNS readiness notifications.
|
||||
// System DNS is disabled in netstack mode.
|
||||
}
|
||||
|
||||
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
32
client/internal/connect_android_embed.go
Normal file
32
client/internal/connect_android_embed.go
Normal file
@@ -0,0 +1,32 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -85,6 +85,16 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
|
||||
@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,8 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -104,6 +106,7 @@ type DefaultServer struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
routeMatch func(netip.Addr) bool
|
||||
|
||||
probeMu sync.Mutex
|
||||
probeCancel context.CancelFunc
|
||||
@@ -149,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -229,6 +232,14 @@ func newDefaultServer(
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
// SetRouteChecker sets the function used by upstream resolvers to determine
|
||||
// whether an IP is routed through the tunnel.
|
||||
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.routeMatch = f
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||
// Any previously registered handler for the same domain and priority will be replaced.
|
||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
@@ -364,6 +375,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
@@ -385,8 +407,12 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
@@ -743,6 +769,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
|
||||
return
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == config.ServerIP {
|
||||
@@ -852,6 +879,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
@@ -1036,6 +1064,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return
|
||||
}
|
||||
handler.routeMatch = s.routeMatch
|
||||
|
||||
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||
handler.deactivate = func(error) {}
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,15 +4,25 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
Stop() error
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -31,25 +35,33 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
86
client/internal/dns/service_listener_test.go
Normal file
86
client/internal/dns/service_listener_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
|
||||
return &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
444
client/internal/dns/tcpstack.go
Normal file
444
client/internal/dns/tcpstack.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,10 +41,61 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -70,6 +121,7 @@ type upstreamResolverBase struct {
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
routeMatch func(netip.Addr) bool
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
@@ -137,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -152,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -167,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -177,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -202,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -427,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -452,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
client.Net = "tcp"
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -478,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -510,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
} else {
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -499,6 +499,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
@@ -510,6 +521,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -1826,6 +1842,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
const renewTimeout = 10 * time.Second
|
||||
const (
|
||||
renewTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
@@ -18,11 +21,13 @@ type Response struct {
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
// Request holds the parameters for exposing a local service via the management server.
|
||||
// It is part of the embed API surface and exposed via a type alias.
|
||||
type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol int
|
||||
Protocol ProtocolType
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
@@ -59,6 +64,8 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||
return fromClientExposeResponse(resp), nil
|
||||
}
|
||||
|
||||
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
|
||||
// It is part of the embed API surface and exposed via a type alias.
|
||||
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) {
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
|
||||
40
client/internal/expose/protocol.go
Normal file
40
client/internal/expose/protocol.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProtocolType represents the protocol used for exposing a service.
|
||||
type ProtocolType int
|
||||
|
||||
const (
|
||||
// ProtocolHTTP exposes the service as HTTP.
|
||||
ProtocolHTTP ProtocolType = 0
|
||||
// ProtocolHTTPS exposes the service as HTTPS.
|
||||
ProtocolHTTPS ProtocolType = 1
|
||||
// ProtocolTCP exposes the service as TCP.
|
||||
ProtocolTCP ProtocolType = 2
|
||||
// ProtocolUDP exposes the service as UDP.
|
||||
ProtocolUDP ProtocolType = 3
|
||||
// ProtocolTLS exposes the service as TLS.
|
||||
ProtocolTLS ProtocolType = 4
|
||||
)
|
||||
|
||||
// ParseProtocolType parses a protocol string into a ProtocolType.
|
||||
func ParseProtocolType(s string) (ProtocolType, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case "http":
|
||||
return ProtocolHTTP, nil
|
||||
case "https":
|
||||
return ProtocolHTTPS, nil
|
||||
case "tcp":
|
||||
return ProtocolTCP, nil
|
||||
case "udp":
|
||||
return ProtocolUDP, nil
|
||||
case "tls":
|
||||
return ProtocolTLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: int(req.Protocol),
|
||||
Protocol: ProtocolType(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
@@ -24,7 +24,7 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: req.Protocol,
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
|
||||
@@ -39,6 +39,18 @@ const (
|
||||
DefaultAdminURL = "https://app.netbird.io:443"
|
||||
)
|
||||
|
||||
// mgmProber is the subset of management client needed for URL migration probes.
|
||||
type mgmProber interface {
|
||||
HealthCheck() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// newMgmProber creates a management client for probing URL reachability.
|
||||
// Overridden in tests to avoid real network calls.
|
||||
var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) {
|
||||
return mgm.NewClient(ctx, addr, key, tlsEnabled)
|
||||
}
|
||||
|
||||
var DefaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
@@ -753,21 +765,19 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
if err = client.HealthCheck(); err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,12 +10,21 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type mockMgmProber struct{}
|
||||
|
||||
func (m *mockMgmProber) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMgmProber) Close() error { return nil }
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
// case 1: new default config has to be generated
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
@@ -234,6 +243,12 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
return &mockMgmProber{}, nil
|
||||
}
|
||||
t.Cleanup(func() { newMgmProber = origProber })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previousManagementURL string
|
||||
@@ -273,18 +288,17 @@ func TestUpdateOldManagementURL(t *testing.T) {
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err, "failed to create testing config")
|
||||
previousStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
previousContent, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err, "failed to read initial config")
|
||||
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
||||
require.NoError(t, err, "got error when updating old management url")
|
||||
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
||||
newStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
switch tt.fileShouldNotChange {
|
||||
case true:
|
||||
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
|
||||
case false:
|
||||
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
|
||||
newContent, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err, "failed to read updated config")
|
||||
if tt.fileShouldNotChange {
|
||||
require.Equal(t, string(previousContent), string(newContent), "file should not change")
|
||||
} else {
|
||||
require.NotEqual(t, string(previousContent), string(newContent), "file should have changed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ type Manager interface {
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||
return maps.Clone(m.clientRoutes)
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes returns only the currently selected/active client routes,
|
||||
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
|
||||
// if traffic should be routed through the tunnel.
|
||||
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -18,6 +18,7 @@ type MockManager struct {
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
if m.GetClientRoutesFunc != nil {
|
||||
return m.GetClientRoutesFunc()
|
||||
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
if m.GetSelectedClientRoutesFunc != nil {
|
||||
return m.GetSelectedClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -31,26 +31,11 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
// initialRoutes contains fake IP block for interface configuration
|
||||
filteredInitial := make([]*route.Route, 0)
|
||||
for _, r := range initialRoutes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredInitial = append(filteredInitial, r)
|
||||
}
|
||||
n.initialRoutes = filteredInitial
|
||||
|
||||
// routesForComparison excludes fake IP block for comparison with new routes
|
||||
filteredComparison := make([]*route.Route, 0)
|
||||
for _, r := range routesForComparison {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredComparison = append(filteredComparison, r)
|
||||
}
|
||||
n.currentRoutes = filteredComparison
|
||||
n.initialRoutes = filterStatic(initialRoutes)
|
||||
n.currentRoutes = filterStatic(routesForComparison)
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
@@ -83,13 +68,43 @@ func (n *Notifier) notify() {
|
||||
return
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||
allRoutes := slices.Clone(n.currentRoutes)
|
||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
||||
|
||||
routeStrings := n.routesToStrings(allRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
||||
for _, r := range n.currentRoutes {
|
||||
currentNets[r.Network] = struct{}{}
|
||||
}
|
||||
|
||||
var extra []*route.Route
|
||||
for _, r := range n.initialRoutes {
|
||||
if _, ok := currentNets[r.Network]; !ok {
|
||||
extra = append(extra, r)
|
||||
}
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
func filterStatic(routes []*route.Route) []*route.Route {
|
||||
out := make([]*route.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
if !r.IsDynamic() {
|
||||
out = append(out, r)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||
nets := make([]string, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -1359,6 +1359,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
if engine.IsBlockInbound() {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
|
||||
@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
sshServer := s.sshServer
|
||||
if sshServer == nil {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
// Close outside the lock: session handlers need s.mu for unregisterSession.
|
||||
if err := sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
s.mu.Lock()
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.connections)
|
||||
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -153,6 +153,9 @@ func networkAddresses() ([]NetworkAddress, error) {
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -324,6 +324,7 @@ type serviceClient struct {
|
||||
exitNodeMu sync.Mutex
|
||||
mExitNodeItems []menuHandler
|
||||
exitNodeRetryCancel context.CancelFunc
|
||||
mExitNodeSeparator *systray.MenuItem
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
|
||||
@@ -421,6 +421,10 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
|
||||
node.Remove()
|
||||
}
|
||||
s.mExitNodeItems = nil
|
||||
if s.mExitNodeSeparator != nil {
|
||||
s.mExitNodeSeparator.Remove()
|
||||
s.mExitNodeSeparator = nil
|
||||
}
|
||||
if s.mExitNodeDeselectAll != nil {
|
||||
s.mExitNodeDeselectAll.Remove()
|
||||
s.mExitNodeDeselectAll = nil
|
||||
@@ -453,31 +457,37 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
|
||||
}
|
||||
|
||||
if showDeselectAll {
|
||||
s.mExitNode.AddSeparator()
|
||||
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
|
||||
s.mExitNodeDeselectAll = deselectAllItem
|
||||
go func() {
|
||||
for {
|
||||
_, ok := <-deselectAllItem.ClickedCh
|
||||
if !ok {
|
||||
// channel closed: exit the goroutine
|
||||
return
|
||||
}
|
||||
exitNodes, err := s.handleExitNodeMenuDeselectAll()
|
||||
if err != nil {
|
||||
log.Warnf("failed to handle deselect all exit nodes: %v", err)
|
||||
} else {
|
||||
s.exitNodeMu.Lock()
|
||||
s.recreateExitNodeMenu(exitNodes)
|
||||
s.exitNodeMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
s.addExitNodeDeselectAll()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *serviceClient) addExitNodeDeselectAll() {
|
||||
sep := s.mExitNode.AddSubMenuItem("───────────────", "")
|
||||
sep.Disable()
|
||||
s.mExitNodeSeparator = sep
|
||||
|
||||
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")
|
||||
s.mExitNodeDeselectAll = deselectAllItem
|
||||
|
||||
go func() {
|
||||
for {
|
||||
_, ok := <-deselectAllItem.ClickedCh
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
exitNodes, err := s.handleExitNodeMenuDeselectAll()
|
||||
if err != nil {
|
||||
log.Warnf("failed to handle deselect all exit nodes: %v", err)
|
||||
} else {
|
||||
s.exitNodeMu.Lock()
|
||||
s.recreateExitNodeMenu(exitNodes)
|
||||
s.exitNodeMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.Network, error) {
|
||||
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
|
||||
defer cancel()
|
||||
|
||||
45
go.mod
45
go.mod
@@ -17,13 +17,13 @@ require (
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/spf13/pflag v1.0.9
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/sys v0.39.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/sys v0.41.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.77.0
|
||||
google.golang.org/protobuf v1.36.10
|
||||
google.golang.org/grpc v1.79.3
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ require (
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.3
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -101,21 +102,21 @@ require (
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
github.com/yusufpapurcu/wmi v1.2.4
|
||||
github.com/zcalusic/sysinfo v1.1.3
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0
|
||||
go.opentelemetry.io/otel v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.38.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0
|
||||
go.opentelemetry.io/otel v1.42.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
|
||||
go.opentelemetry.io/otel/metric v1.42.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0
|
||||
go.uber.org/mock v0.5.2
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/mod v0.32.0
|
||||
golang.org/x/net v0.51.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.38.0
|
||||
golang.org/x/term v0.40.0
|
||||
golang.org/x/time v0.14.0
|
||||
google.golang.org/api v0.257.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -181,7 +182,6 @@ require (
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -249,12 +249,13 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/otlptranslator v1.0.0 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/russellhaering/goxmldsig v1.5.0 // indirect
|
||||
github.com/rymdport/portal v0.4.2 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.1 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.2.0 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.2.1 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/cast v1.7.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
@@ -269,15 +270,15 @@ require (
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.42.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/image v0.33.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
)
|
||||
|
||||
90
go.sum
90
go.sum
@@ -487,10 +487,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
|
||||
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
|
||||
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
@@ -511,8 +513,8 @@ github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKd
|
||||
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.1/go.mod h1:RoUCUpndaJFtT+2zsZzzmhvbfGoDCJ7nFXKJf8GqJbI=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/go-m1cpu v0.2.0 h1:t4GNqvPZ84Vjtpboo/kT3pIkbaK3vc+JIlD/Wz1zSFY=
|
||||
github.com/shoenig/go-m1cpu v0.2.0/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w=
|
||||
github.com/shoenig/go-m1cpu v0.2.1 h1:yqRB4fvOge2+FyRXFkXqsyMoqPazv14Yyy+iyccT2E4=
|
||||
github.com/shoenig/go-m1cpu v0.2.1/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk=
|
||||
github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI=
|
||||
@@ -603,26 +605,26 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
|
||||
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.64.0 h1:g0LRDXMX/G1SEZtK8zl8Chm4K6GBwRkjPKE36LxiTYs=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.64.0/go.mod h1:UrgcjnarfdlBDP3GjDIJWe6HTprwSazNjwsI+Ru6hro=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
@@ -633,8 +635,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
|
||||
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -648,8 +650,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
|
||||
@@ -666,8 +668,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
@@ -686,8 +688,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
@@ -738,8 +740,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -752,8 +754,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -765,8 +767,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -780,8 +782,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -799,12 +801,12 @@ google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
||||
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
|
||||
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
@@ -815,8 +817,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
|
||||
@@ -170,20 +170,66 @@ type Connector struct {
|
||||
}
|
||||
|
||||
// ToStorageConnector converts a Connector to storage.Connector type.
|
||||
// It maps custom connector types (e.g., "zitadel", "entra") to Dex-native types
|
||||
// and augments the config with OIDC defaults when needed.
|
||||
func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
||||
data, err := json.Marshal(c.Config)
|
||||
dexType, augmentedConfig := mapConnectorToDex(c.Type, c.Config)
|
||||
|
||||
data, err := json.Marshal(augmentedConfig)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to marshal connector config: %v", err)
|
||||
}
|
||||
|
||||
return storage.Connector{
|
||||
ID: c.ID,
|
||||
Type: c.Type,
|
||||
Type: dexType,
|
||||
Name: c.Name,
|
||||
Config: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mapConnectorToDex maps custom connector types to Dex-native types and applies
|
||||
// OIDC defaults. This ensures static connectors from config files or env vars
|
||||
// are stored with types that Dex can open.
|
||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||
switch connType {
|
||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||
return "oidc", applyOIDCDefaults(connType, config)
|
||||
default:
|
||||
return connType, config
|
||||
}
|
||||
}
|
||||
|
||||
// applyOIDCDefaults clones the config map, sets common OIDC defaults,
|
||||
// and applies provider-specific overrides.
|
||||
func applyOIDCDefaults(connType string, config map[string]interface{}) map[string]interface{} {
|
||||
augmented := make(map[string]interface{}, len(config)+4)
|
||||
for k, v := range config {
|
||||
augmented[k] = v
|
||||
}
|
||||
setDefault(augmented, "scopes", []string{"openid", "profile", "email"})
|
||||
setDefault(augmented, "insecureEnableGroups", true)
|
||||
setDefault(augmented, "insecureSkipEmailVerified", true)
|
||||
|
||||
switch connType {
|
||||
case "zitadel":
|
||||
setDefault(augmented, "getUserInfo", true)
|
||||
case "entra":
|
||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||
case "okta", "pocketid":
|
||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
|
||||
return augmented
|
||||
}
|
||||
|
||||
// setDefault sets a key in the map only if it doesn't already exist.
|
||||
func setDefault(m map[string]interface{}, key string, value interface{}) {
|
||||
if _, ok := m[key]; !ok {
|
||||
m[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// StorageConfig is a configuration that can create a storage.
|
||||
type StorageConfig interface {
|
||||
Open(logger *slog.Logger) (storage.Storage, error)
|
||||
|
||||
@@ -4,6 +4,7 @@ package dex
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -19,10 +20,13 @@ import (
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Config matches what management/internals/server/server.go expects
|
||||
@@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string {
|
||||
}
|
||||
return issuer + "/auth"
|
||||
}
|
||||
|
||||
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
|
||||
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
|
||||
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
|
||||
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
|
||||
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
|
||||
keys, err := p.storage.GetKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
|
||||
}
|
||||
|
||||
if keys.SigningKeyPub == nil {
|
||||
return nil, fmt.Errorf("no public keys found in storage")
|
||||
}
|
||||
|
||||
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
|
||||
// signing key first, then all verification (rotated) keys.
|
||||
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
|
||||
joseKeys = append(joseKeys, *keys.SigningKeyPub)
|
||||
for _, vk := range keys.VerificationKeys {
|
||||
if vk.PublicKey != nil {
|
||||
joseKeys = append(joseKeys, *vk.PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
|
||||
// then deserialize into our Jwks type, so the JSON field mapping is identical
|
||||
// to what the /keys HTTP endpoint would return.
|
||||
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
|
||||
data, err := json.Marshal(joseSet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
|
||||
}
|
||||
|
||||
jwks := &nbjwt.Jwks{}
|
||||
if err := json.Unmarshal(data, jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
|
||||
}
|
||||
|
||||
jwks.ExpiresInTime = keys.NextRotation
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,14 @@ package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
sqllib "github.com/dexidp/dex/storage/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -197,6 +200,295 @@ enablePasswordDB: true
|
||||
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
|
||||
}
|
||||
|
||||
// openTestStorage creates a SQLite storage in the given directory for testing.
|
||||
func openTestStorage(t *testing.T, tmpDir string) storage.Storage {
|
||||
t.Helper()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
stor, err := (&sqllib.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
|
||||
require.NoError(t, err)
|
||||
return stor
|
||||
}
|
||||
|
||||
func TestStaticConnectors_CreatedFromYAML(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-static-conn-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
connectors:
|
||||
- type: oidc
|
||||
id: my-oidc
|
||||
name: My OIDC Provider
|
||||
config:
|
||||
issuer: https://accounts.example.com
|
||||
clientID: test-client-id
|
||||
clientSecret: test-client-secret
|
||||
redirectURI: http://localhost:5556/dex/callback
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Open storage and run initializeStorage directly (avoids Dex server
|
||||
// trying to dial the OIDC issuer)
|
||||
stor := openTestStorage(t, tmpDir)
|
||||
defer stor.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connector was created in storage
|
||||
conn, err := stor.GetConnector(ctx, "my-oidc")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "my-oidc", conn.ID)
|
||||
assert.Equal(t, "My OIDC Provider", conn.Name)
|
||||
assert.Equal(t, "oidc", conn.Type)
|
||||
|
||||
// Verify config fields were serialized correctly
|
||||
var configMap map[string]interface{}
|
||||
err = json.Unmarshal(conn.Config, &configMap)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://accounts.example.com", configMap["issuer"])
|
||||
assert.Equal(t, "test-client-id", configMap["clientID"])
|
||||
}
|
||||
|
||||
func TestStaticConnectors_UpdatedOnRestart(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-static-conn-update-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbFile := filepath.Join(tmpDir, "dex.db")
|
||||
|
||||
// First: load config with initial connector
|
||||
yamlContent1 := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + dbFile + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
connectors:
|
||||
- type: oidc
|
||||
id: my-oidc
|
||||
name: Original Name
|
||||
config:
|
||||
issuer: https://accounts.example.com
|
||||
clientID: original-client-id
|
||||
clientSecret: original-secret
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent1), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig1, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
stor := openTestStorage(t, tmpDir)
|
||||
err = initializeStorage(ctx, stor, yamlConfig1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify initial state
|
||||
conn, err := stor.GetConnector(ctx, "my-oidc")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Original Name", conn.Name)
|
||||
|
||||
var configMap1 map[string]interface{}
|
||||
err = json.Unmarshal(conn.Config, &configMap1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "original-client-id", configMap1["clientID"])
|
||||
|
||||
// Close storage to simulate restart
|
||||
stor.Close()
|
||||
|
||||
// Second: load updated config against the same DB
|
||||
yamlContent2 := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + dbFile + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
connectors:
|
||||
- type: oidc
|
||||
id: my-oidc
|
||||
name: Updated Name
|
||||
config:
|
||||
issuer: https://accounts.example.com
|
||||
clientID: updated-client-id
|
||||
clientSecret: updated-secret
|
||||
`
|
||||
err = os.WriteFile(configPath, []byte(yamlContent2), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig2, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
stor2 := openTestStorage(t, tmpDir)
|
||||
defer stor2.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor2, yamlConfig2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify connector was updated, not duplicated
|
||||
allConnectors, err := stor2.ListConnectors(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonLocalCount := 0
|
||||
for _, c := range allConnectors {
|
||||
if c.ID != "local" {
|
||||
nonLocalCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, nonLocalCount, "connector should be updated, not duplicated")
|
||||
|
||||
conn2, err := stor2.GetConnector(ctx, "my-oidc")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", conn2.Name)
|
||||
|
||||
var configMap2 map[string]interface{}
|
||||
err = json.Unmarshal(conn2.Config, &configMap2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated-client-id", configMap2["clientID"])
|
||||
}
|
||||
|
||||
func TestStaticConnectors_MultipleConnectors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-static-conn-multi-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
connectors:
|
||||
- type: oidc
|
||||
id: my-oidc
|
||||
name: My OIDC Provider
|
||||
config:
|
||||
issuer: https://accounts.example.com
|
||||
clientID: oidc-client-id
|
||||
clientSecret: oidc-secret
|
||||
- type: google
|
||||
id: my-google
|
||||
name: Google Login
|
||||
config:
|
||||
clientID: google-client-id
|
||||
clientSecret: google-secret
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
stor := openTestStorage(t, tmpDir)
|
||||
defer stor.Close()
|
||||
|
||||
err = initializeStorage(ctx, stor, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
allConnectors, err := stor.ListConnectors(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a map for easier assertion
|
||||
connByID := make(map[string]storage.Connector)
|
||||
for _, c := range allConnectors {
|
||||
connByID[c.ID] = c
|
||||
}
|
||||
|
||||
// Verify both static connectors exist
|
||||
oidcConn, ok := connByID["my-oidc"]
|
||||
require.True(t, ok, "oidc connector should exist")
|
||||
assert.Equal(t, "My OIDC Provider", oidcConn.Name)
|
||||
assert.Equal(t, "oidc", oidcConn.Type)
|
||||
|
||||
var oidcConfig map[string]interface{}
|
||||
err = json.Unmarshal(oidcConn.Config, &oidcConfig)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc-client-id", oidcConfig["clientID"])
|
||||
|
||||
googleConn, ok := connByID["my-google"]
|
||||
require.True(t, ok, "google connector should exist")
|
||||
assert.Equal(t, "Google Login", googleConn.Name)
|
||||
assert.Equal(t, "google", googleConn.Type)
|
||||
|
||||
var googleConfig map[string]interface{}
|
||||
err = json.Unmarshal(googleConn.Config, &googleConfig)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "google-client-id", googleConfig["clientID"])
|
||||
|
||||
// Verify local connector still exists alongside them (enablePasswordDB: true)
|
||||
localConn, ok := connByID["local"]
|
||||
require.True(t, ok, "local connector should exist")
|
||||
assert.Equal(t, "local", localConn.Type)
|
||||
}
|
||||
|
||||
func TestStaticConnectors_EmptyList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "dex-static-conn-empty-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
yamlContent := `
|
||||
issuer: http://localhost:5556/dex
|
||||
storage:
|
||||
type: sqlite3
|
||||
config:
|
||||
file: ` + filepath.Join(tmpDir, "dex.db") + `
|
||||
web:
|
||||
http: 127.0.0.1:5556
|
||||
enablePasswordDB: true
|
||||
`
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
yamlConfig, err := LoadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider, err := NewProviderFromYAML(ctx, yamlConfig)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = provider.Stop(ctx) }()
|
||||
|
||||
// No static connectors configured, so ListConnectors should return empty
|
||||
connectors, err := provider.ListConnectors(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, connectors)
|
||||
|
||||
// But local connector should still exist
|
||||
localConn, err := provider.Storage().GetConnector(ctx, "local")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "local", localConn.ID)
|
||||
}
|
||||
|
||||
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -172,8 +172,11 @@ init_environment() {
|
||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
echo ""
|
||||
echo "Login with the following credentials:"
|
||||
echo "Email: admin@$NETBIRD_DOMAIN" | tee .env
|
||||
echo "Password: $NETBIRD_ADMIN_PASSWORD" | tee -a .env
|
||||
install -m 600 /dev/null .env
|
||||
printf 'Email: admin@%s\nPassword: %s\n' \
|
||||
"$NETBIRD_DOMAIN" "$NETBIRD_ADMIN_PASSWORD" >> .env
|
||||
echo "Email: admin@$NETBIRD_DOMAIN"
|
||||
echo "Password: $NETBIRD_ADMIN_PASSWORD"
|
||||
echo ""
|
||||
echo "Dex admin UI is not available (Dex has no built-in UI)."
|
||||
echo "To add more users, edit dex.yaml and restart: $DOCKER_COMPOSE_COMMAND restart dex"
|
||||
|
||||
@@ -563,8 +563,11 @@ initEnvironment() {
|
||||
echo -e "\nDone!\n"
|
||||
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
echo "Login with the following credentials:"
|
||||
echo "Username: $ZITADEL_ADMIN_USERNAME" | tee .env
|
||||
echo "Password: $ZITADEL_ADMIN_PASSWORD" | tee -a .env
|
||||
install -m 600 /dev/null .env
|
||||
printf 'Username: %s\nPassword: %s\n' \
|
||||
"$ZITADEL_ADMIN_USERNAME" "$ZITADEL_ADMIN_PASSWORD" >> .env
|
||||
echo "Username: $ZITADEL_ADMIN_USERNAME"
|
||||
echo "Password: $ZITADEL_ADMIN_PASSWORD"
|
||||
}
|
||||
|
||||
renderCaddyfile() {
|
||||
|
||||
@@ -1154,7 +1154,16 @@ print_builtin_traefik_instructions() {
|
||||
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||
echo ""
|
||||
fi
|
||||
echo ""
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
echo ""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo "NetBird Proxy:"
|
||||
echo " The proxy service is enabled and running."
|
||||
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
|
||||
|
||||
@@ -154,9 +154,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -31,19 +31,15 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
type clusterCapabilities interface {
|
||||
ClusterSupportsCustomPorts(clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(clusterAddr string) *bool
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
clusterCapabilities clusterCapabilities
|
||||
permissionsManager permissions.Manager
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
@@ -57,11 +53,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
|
||||
}
|
||||
}
|
||||
|
||||
// SetClusterCapabilities sets the cluster capabilities provider for domain queries.
|
||||
func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) {
|
||||
m.clusterCapabilities = caps
|
||||
}
|
||||
|
||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
@@ -97,10 +88,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
Type: domain.TypeFree,
|
||||
Validated: true,
|
||||
}
|
||||
if m.clusterCapabilities != nil {
|
||||
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
|
||||
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
|
||||
}
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -114,8 +103,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
Type: domain.TypeCustom,
|
||||
Validated: d.Validated,
|
||||
}
|
||||
if m.clusterCapabilities != nil && d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -11,11 +11,13 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
@@ -34,6 +36,4 @@ type Controller interface {
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
ClusterSupportsCustomPorts(clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(clusterAddr string) *bool
|
||||
}
|
||||
|
||||
@@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports.
|
||||
func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
|
||||
// Returns nil when no proxy has reported the capability (defaults to false).
|
||||
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
|
||||
@@ -16,6 +16,8 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
@@ -38,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
caps = *capabilities
|
||||
}
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
@@ -48,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
@@ -118,6 +126,18 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether any active proxy in the cluster
|
||||
// requires a subdomain. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterRequireSubdomain(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
|
||||
@@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// ClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain mocks base method.
|
||||
func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
|
||||
func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -145,34 +173,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
|
||||
func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain mocks base method.
|
||||
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
|
||||
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -2,6 +2,17 @@ package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||
// Nil fields mean the proxy never reported this capability.
|
||||
type Capabilities struct {
|
||||
// SupportsCustomPorts indicates whether this proxy can bind arbitrary
|
||||
// ports for TCP/UDP services. TLS uses SNI routing and is not gated.
|
||||
SupportsCustomPorts *bool
|
||||
// RequireSubdomain indicates whether a subdomain label is required in
|
||||
// front of the cluster domain.
|
||||
RequireSubdomain *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
@@ -11,6 +22,7 @@ type Proxy struct {
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
@@ -75,16 +75,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
require.NoError(t, err)
|
||||
|
||||
mockCtrl := proxy.NewMockController(ctrl)
|
||||
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
|
||||
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
|
||||
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
mockCaps := proxy.NewMockManager(ctrl)
|
||||
mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes()
|
||||
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -93,6 +95,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permissions.NewManager(testStore),
|
||||
proxyController: mockCtrl,
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
@@ -75,22 +75,30 @@ type ClusterDeriver interface {
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
// CapabilityProvider queries proxy cluster capabilities from the database.
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
@@ -237,7 +245,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
|
||||
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -268,11 +276,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
// validateSubdomainRequirement checks whether the domain can be used bare
|
||||
// (without a subdomain label) on the given cluster. If the cluster reports
|
||||
// require_subdomain=true and the domain equals the cluster domain, it rejects.
|
||||
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
|
||||
func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error {
|
||||
if domain != cluster {
|
||||
return nil
|
||||
}
|
||||
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
|
||||
requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster)
|
||||
if requireSub != nil && *requireSub {
|
||||
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
|
||||
}
|
||||
@@ -280,6 +288,8 @@ func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
|
||||
}
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
@@ -287,7 +297,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -307,12 +317,23 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
})
|
||||
}
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
|
||||
// clusterCustomPorts queries whether the cluster supports custom ports.
|
||||
// Must be called before entering a transaction: the underlying query uses
|
||||
// the main DB handle, which deadlocks when called inside a transaction
|
||||
// that already holds the connection.
|
||||
func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster)
|
||||
}
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
|
||||
@@ -396,12 +417,14 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -504,21 +527,58 @@ type serviceUpdateInfo struct {
|
||||
}
|
||||
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svcForCaps := *service
|
||||
svcForCaps.ProxyCluster = effectiveCluster
|
||||
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
|
||||
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
|
||||
// resolveEffectiveCluster determines the cluster that will be used after the update.
|
||||
// It reads the existing service without locking and derives the new cluster if the domain changed.
|
||||
func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) {
|
||||
existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if existing.Domain == svc.Domain {
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
return derived, nil
|
||||
}
|
||||
}
|
||||
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existingService.Terminated {
|
||||
return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated")
|
||||
}
|
||||
|
||||
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -534,7 +594,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -546,7 +606,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
@@ -1059,7 +1119,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
|
||||
@@ -698,8 +698,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1324,11 +1324,11 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
mockCtrl := proxy.NewMockController(ctrl)
|
||||
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
|
||||
mockCaps := proxy.NewMockManager(ctrl)
|
||||
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes()
|
||||
|
||||
mgr := &Manager{proxyController: mockCtrl}
|
||||
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
|
||||
mgr := &Manager{capabilities: mockCaps}
|
||||
err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "requires a subdomain label")
|
||||
|
||||
@@ -184,6 +184,7 @@ type Service struct {
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
Terminated bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
@@ -256,7 +257,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
||||
TargetId: target.TargetId,
|
||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||
Enabled: target.Enabled,
|
||||
Enabled: target.Enabled && !s.Terminated,
|
||||
}
|
||||
opts := targetOptionsToAPI(target.Options)
|
||||
if opts == nil {
|
||||
@@ -286,7 +287,8 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
Enabled: s.Enabled && !s.Terminated,
|
||||
Terminated: &s.Terminated,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
@@ -785,6 +787,11 @@ func (s *Service) validateHTTPTargets() error {
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
@@ -1125,6 +1132,7 @@ func (s *Service) Copy() *Service {
|
||||
ProxyCluster: s.ProxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: s.Enabled,
|
||||
Terminated: s.Terminated,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: authCopy,
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
|
||||
issuer := s.Config.HttpConfig.AuthIssuer
|
||||
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
|
||||
var keyFetcher nbjwt.KeyFetcher
|
||||
|
||||
// Use embedded IdP configuration if available
|
||||
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
|
||||
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
keyFetcher = oauthProvider.GetKeyFetcher()
|
||||
// Fall back to default keys location if direct key fetching is not available
|
||||
if keyFetcher == nil {
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
}
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
keysLocation,
|
||||
userIDClaim,
|
||||
audiences,
|
||||
signingKeyRefreshEnabled)
|
||||
signingKeyRefreshEnabled,
|
||||
keyFetcher)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -117,9 +117,11 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
|
||||
if embeddedEnabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
@@ -195,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
||||
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -212,9 +214,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
m.SetClusterCapabilities(s.ServiceProxyController())
|
||||
})
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
}
|
||||
}
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
@@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
return mappings, nil
|
||||
@@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
|
||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||
for _, proxyID := range proxyIDs {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
connVal, ok := s.connectedProxies.Load(proxyID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
// Old proxies that never reported capabilities don't understand
|
||||
// custom port mappings.
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -508,64 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
|
||||
// cluster reports custom port support. Returns nil if no proxy has reported
|
||||
// capabilities (old proxies that predate the field).
|
||||
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.SupportsCustomPorts {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether any connected proxy in the given
|
||||
// cluster reports that a subdomain is required. Returns nil if no proxy has
|
||||
// reported the capability (defaults to not required).
|
||||
func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.RequireSubdomain {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
|
||||
@@ -53,14 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
|
||||
return ptr(true)
|
||||
}
|
||||
|
||||
func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -355,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
// Proxy A supports custom ports.
|
||||
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
// Proxy B does NOT support custom ports (shared cloud proxy).
|
||||
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
// Modern proxy reports capabilities.
|
||||
chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
// Legacy proxy never reported capabilities (nil).
|
||||
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support.
|
||||
// TLS passthrough with custom port: all proxies receive it (SNI routing).
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tls",
|
||||
@@ -375,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
|
||||
|
||||
msgA := drainMapping(chA)
|
||||
msgB := drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping")
|
||||
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)")
|
||||
|
||||
// Send an HTTP mapping: both should receive it.
|
||||
// TCP mapping with custom port: only modern proxy receives it.
|
||||
tcpMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tcp",
|
||||
AccountId: "account-1",
|
||||
Domain: "db.example.com",
|
||||
Mode: "tcp",
|
||||
ListenPort: 5432,
|
||||
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
|
||||
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping")
|
||||
assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping")
|
||||
|
||||
// HTTP mapping (no listen port): both receive it.
|
||||
httpMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-http",
|
||||
@@ -391,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
|
||||
|
||||
msgA = drainMapping(chA)
|
||||
msgB = drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping")
|
||||
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping")
|
||||
|
||||
// Proxy that reports SupportsCustomPorts=false still receives custom-port
|
||||
// mappings because it understands the protocol (it's new enough).
|
||||
chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
|
||||
|
||||
assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping")
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
@@ -408,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
// Legacy proxy (no capabilities) still receives TLS since it uses SNI.
|
||||
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
|
||||
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
@@ -421,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
|
||||
|
||||
msg := drainMapping(chShared)
|
||||
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
|
||||
msg := drainMapping(chLegacy)
|
||||
assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)")
|
||||
}
|
||||
|
||||
// TestServiceModifyNotifications exercises every possible modification
|
||||
@@ -589,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) {
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
chLegacy := registerFakeProxy(s, "legacy", cluster)
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support
|
||||
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
|
||||
@@ -608,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) {
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
chLegacy := registerFakeProxy(s, "legacy", cluster)
|
||||
|
||||
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.ListenPort = 0 // default port
|
||||
|
||||
@@ -75,7 +75,7 @@ type Manager interface {
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
|
||||
@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
|
||||
ret0, _ := ret[0].(*types.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks base method.
|
||||
|
||||
@@ -3138,7 +3138,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
61
management/server/activity/store/sql_store_idp_migration.go
Normal file
61
management/server/activity/store/sql_store_idp_migration.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package store
|
||||
|
||||
// This file contains migration-only methods on Store.
|
||||
// They satisfy the migration.MigrationEventStore interface via duck typing.
|
||||
// Delete this file when migration tooling is no longer needed.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp/migration"
|
||||
)
|
||||
|
||||
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
|
||||
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
|
||||
migrator := store.db.Migrator()
|
||||
var errs []migration.SchemaError
|
||||
|
||||
for _, check := range checks {
|
||||
if !migrator.HasTable(check.Table) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table})
|
||||
continue
|
||||
}
|
||||
for _, col := range check.Columns {
|
||||
if !migrator.HasColumn(check.Table, col) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
|
||||
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
|
||||
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&activity.Event{}).
|
||||
Where("initiator_id = ?", oldUserID).
|
||||
Update("initiator_id", newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update events.initiator_id: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Model(&activity.Event{}).
|
||||
Where("target_id = ?", oldUserID).
|
||||
Update("target_id", newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update events.target_id: %w", err)
|
||||
}
|
||||
|
||||
// Raw exec: GORM can't update a PK via Model().Update()
|
||||
if err := tx.Exec(
|
||||
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("update deleted_users.id: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
161
management/server/activity/store/sql_store_idp_migration_test.go
Normal file
161
management/server/activity/store/sql_store_idp_migration_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUpdateUserID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
newStore := func(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
key, _ := crypt.GenerateKey()
|
||||
s, err := NewSqlStore(ctx, t.TempDir(), key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close(ctx) }) //nolint
|
||||
return s
|
||||
}
|
||||
|
||||
t.Run("updates initiator_id in events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "old-user",
|
||||
TargetID: "some-peer",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "new-user", result[0].InitiatorID)
|
||||
})
|
||||
|
||||
t.Run("updates target_id in events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "some-admin",
|
||||
TargetID: "old-user",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "new-user", result[0].TargetID)
|
||||
})
|
||||
|
||||
t.Run("updates deleted_users id", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
// Save an event with email/name meta to create a deleted_users row for "old-user"
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "admin",
|
||||
TargetID: "old-user",
|
||||
AccountID: accountID,
|
||||
Meta: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Save another event referencing new-user with email/name meta.
|
||||
// This should upsert (not conflict) because the PK was already migrated.
|
||||
_, err = store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "admin",
|
||||
TargetID: "new-user",
|
||||
AccountID: accountID,
|
||||
Meta: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// The deleted user info should be retrievable via Get (joined on target_id)
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
for _, ev := range result {
|
||||
assert.Equal(t, "new-user", ev.TargetID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user-a",
|
||||
TargetID: "peer-1",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user-b",
|
||||
TargetID: "peer-2",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
|
||||
for _, ev := range result {
|
||||
if ev.TargetID == "peer-1" {
|
||||
assert.Equal(t, "user-a-new", ev.InitiatorID)
|
||||
} else {
|
||||
assert.Equal(t, "user-b", ev.InitiatorID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -33,15 +33,20 @@ type manager struct {
|
||||
extractor *nbjwt.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
|
||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||
// but it will fail when validating and parsing the token
|
||||
jwtValidator := nbjwt.NewValidator(
|
||||
issuer,
|
||||
allAudiences,
|
||||
keysLocation,
|
||||
idpRefreshKeys,
|
||||
)
|
||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
|
||||
var jwtValidator *nbjwt.Validator
|
||||
if keyFetcher != nil {
|
||||
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
|
||||
} else {
|
||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||
// but it will fail when validating and parsing the token
|
||||
jwtValidator = nbjwt.NewValidator(
|
||||
issuer,
|
||||
allAudiences,
|
||||
keysLocation,
|
||||
idpRefreshKeys,
|
||||
)
|
||||
}
|
||||
|
||||
claimsExtractor := nbjwt.NewClaimsExtractor(
|
||||
nbjwt.WithAudience(audience),
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
// these tests only assert groups are parsed from token as per account settings
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
keyId := "test-key"
|
||||
|
||||
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
|
||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
|
||||
|
||||
customClaim := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", audience, name)
|
||||
|
||||
@@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
|
||||
gl.mux.RLock()
|
||||
defer gl.mux.RUnlock()
|
||||
|
||||
if gl.db == nil {
|
||||
return nil, fmt.Errorf("geolocation database is not available")
|
||||
}
|
||||
|
||||
var record Record
|
||||
err := gl.db.Lookup(ip, &record)
|
||||
if err != nil {
|
||||
@@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er
|
||||
|
||||
func (gl *geolocationImpl) Stop() error {
|
||||
close(gl.stopCh)
|
||||
if gl.db != nil {
|
||||
if err := gl.db.Close(); err != nil {
|
||||
|
||||
gl.mux.Lock()
|
||||
db := gl.db
|
||||
gl.db = nil
|
||||
gl.mux.Unlock()
|
||||
|
||||
if db != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
|
||||
return groups, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Accounts_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, true},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all accounts", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.Account{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
account := got[0]
|
||||
assert.Equal(t, "test.com", account.Domain)
|
||||
assert.Equal(t, "private", account.DomainCategory)
|
||||
assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, 86400, account.Settings.PeerLoginExpiration)
|
||||
assert.Equal(t, false, account.Settings.RegularUsersViewBlocked)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Accounts_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestBody *api.AccountRequest
|
||||
verifyResponse func(t *testing.T, account *api.Account)
|
||||
verifyDB func(t *testing.T, account *types.Account)
|
||||
}{
|
||||
{
|
||||
name: "Disable peer login expiration",
|
||||
requestBody: &api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: 86400,
|
||||
},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, account *api.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled)
|
||||
},
|
||||
verifyDB: func(t *testing.T, dbAccount *types.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update peer login expiration to 48h",
|
||||
requestBody: &api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 172800,
|
||||
},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, account *api.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, 172800, account.Settings.PeerLoginExpiration)
|
||||
},
|
||||
verifyDB: func(t *testing.T, dbAccount *types.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Enable regular users view blocked",
|
||||
requestBody: &api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 86400,
|
||||
RegularUsersViewBlocked: true,
|
||||
},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, account *api.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, true, account.Settings.RegularUsersViewBlocked)
|
||||
},
|
||||
verifyDB: func(t *testing.T, dbAccount *types.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Enable groups propagation",
|
||||
requestBody: &api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 86400,
|
||||
GroupsPropagationEnabled: &trueVal,
|
||||
},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, account *api.Account) {
|
||||
t.Helper()
|
||||
assert.NotNil(t, account.Settings.GroupsPropagationEnabled)
|
||||
assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled)
|
||||
},
|
||||
verifyDB: func(t *testing.T, dbAccount *types.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Enable JWT groups",
|
||||
requestBody: &api.AccountRequest{
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 86400,
|
||||
GroupsPropagationEnabled: &falseVal,
|
||||
JwtGroupsEnabled: &trueVal,
|
||||
JwtGroupsClaimName: stringPointer("groups"),
|
||||
},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, account *api.Account) {
|
||||
t.Helper()
|
||||
assert.NotNil(t, account.Settings.JwtGroupsEnabled)
|
||||
assert.Equal(t, true, *account.Settings.JwtGroupsEnabled)
|
||||
assert.NotNil(t, account.Settings.JwtGroupsClaimName)
|
||||
assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName)
|
||||
},
|
||||
verifyDB: func(t *testing.T, dbAccount *types.Account) {
|
||||
t.Helper()
|
||||
assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled)
|
||||
assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := &api.Account{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, testing_tools.TestAccountId, got.Id)
|
||||
assert.Equal(t, "test.com", got.Domain)
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbAccount := testing_tools.VerifyAccountSettings(t, db)
|
||||
tc.verifyDB(t, dbAccount)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stringPointer(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,554 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_Nameservers_GetAll(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get all nameservers", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/nameservers", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := []api.NameserverGroup{}
|
||||
if err := json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(got))
|
||||
assert.Equal(t, "testNSGroup", got[0].Name)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Nameservers_GetById(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
nsGroupId string
|
||||
expectedStatus int
|
||||
expectGroup bool
|
||||
}{
|
||||
{
|
||||
name: "Get existing nameserver group",
|
||||
nsGroupId: "testNSGroupId",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectGroup: true,
|
||||
},
|
||||
{
|
||||
name: "Get non-existing nameserver group",
|
||||
nsGroupId: "nonExistingNSGroupId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectGroup: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectGroup {
|
||||
got := &api.NameserverGroup{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, "testNSGroupId", got.Id)
|
||||
assert.Equal(t, "testNSGroup", got.Name)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Nameservers_Create(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody *api.PostApiDnsNameserversJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
|
||||
}{
|
||||
{
|
||||
name: "Create nameserver group with single NS",
|
||||
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
|
||||
Name: "newNSGroup",
|
||||
Description: "a new nameserver group",
|
||||
Nameservers: []api.Nameserver{
|
||||
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
|
||||
},
|
||||
Groups: []string{testing_tools.TestGroupId},
|
||||
Primary: false,
|
||||
Domains: []string{"test.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: false,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
|
||||
t.Helper()
|
||||
assert.NotEmpty(t, nsGroup.Id)
|
||||
assert.Equal(t, "newNSGroup", nsGroup.Name)
|
||||
assert.Equal(t, 1, len(nsGroup.Nameservers))
|
||||
assert.Equal(t, false, nsGroup.Primary)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create primary nameserver group",
|
||||
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
|
||||
Name: "primaryNS",
|
||||
Description: "primary nameserver",
|
||||
Nameservers: []api.Nameserver{
|
||||
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
|
||||
},
|
||||
Groups: []string{testing_tools.TestGroupId},
|
||||
Primary: true,
|
||||
Domains: []string{},
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
|
||||
t.Helper()
|
||||
assert.Equal(t, true, nsGroup.Primary)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create nameserver group with empty groups",
|
||||
requestBody: &api.PostApiDnsNameserversJSONRequestBody{
|
||||
Name: "emptyGroupsNS",
|
||||
Description: "no groups",
|
||||
Nameservers: []api.Nameserver{
|
||||
{Ip: "8.8.8.8", NsType: "udp", Port: 53},
|
||||
},
|
||||
Groups: []string{},
|
||||
Primary: true,
|
||||
Domains: []string{},
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/nameservers", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.NameserverGroup{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the created NS group directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbNS := testing_tools.VerifyNSGroupInDB(t, db, got.Id)
|
||||
assert.Equal(t, got.Name, dbNS.Name)
|
||||
assert.Equal(t, got.Primary, dbNS.Primary)
|
||||
assert.Equal(t, len(got.Nameservers), len(dbNS.NameServers))
|
||||
assert.Equal(t, got.Enabled, dbNS.Enabled)
|
||||
assert.Equal(t, got.SearchDomainsEnabled, dbNS.SearchDomainsEnabled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Nameservers_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
nsGroupId string
|
||||
requestBody *api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, nsGroup *api.NameserverGroup)
|
||||
}{
|
||||
{
|
||||
name: "Update nameserver group name",
|
||||
nsGroupId: "testNSGroupId",
|
||||
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
|
||||
Name: "updatedNSGroup",
|
||||
Description: "updated description",
|
||||
Nameservers: []api.Nameserver{
|
||||
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
|
||||
},
|
||||
Groups: []string{testing_tools.TestGroupId},
|
||||
Primary: false,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: false,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, nsGroup *api.NameserverGroup) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "updatedNSGroup", nsGroup.Name)
|
||||
assert.Equal(t, "updated description", nsGroup.Description)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing nameserver group",
|
||||
nsGroupId: "nonExistingNSGroupId",
|
||||
requestBody: &api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
|
||||
Name: "whatever",
|
||||
Nameservers: []api.Nameserver{
|
||||
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
|
||||
},
|
||||
Groups: []string{testing_tools.TestGroupId},
|
||||
Primary: true,
|
||||
Domains: []string{},
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.NameserverGroup{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify the updated NS group directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbNS := testing_tools.VerifyNSGroupInDB(t, db, tc.nsGroupId)
|
||||
assert.Equal(t, "updatedNSGroup", dbNS.Name)
|
||||
assert.Equal(t, "updated description", dbNS.Description)
|
||||
assert.Equal(t, false, dbNS.Primary)
|
||||
assert.Equal(t, true, dbNS.Enabled)
|
||||
assert.Equal(t, 1, len(dbNS.NameServers))
|
||||
assert.Equal(t, false, dbNS.SearchDomainsEnabled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Nameservers_Delete(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
nsGroupId string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete existing nameserver group",
|
||||
nsGroupId: "testNSGroupId",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete non-existing nameserver group",
|
||||
nsGroupId: "nonExistingNSGroupId",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/nameservers/{nsgroupId}", "{nsgroupId}", tc.nsGroupId, 1), user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
|
||||
// Verify deletion in DB for successful deletes by privileged users
|
||||
if tc.expectedStatus == http.StatusOK && user.expectResponse {
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
testing_tools.VerifyNSGroupNotInDB(t, db, tc.nsGroupId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DnsSettings_Get(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - Get DNS settings", func(t *testing.T) {
|
||||
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/settings", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
got := &api.DNSSettings{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.NotNil(t, got.DisabledManagementGroups)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DnsSettings_Update(t *testing.T) {
|
||||
users := []struct {
|
||||
name string
|
||||
userId string
|
||||
expectResponse bool
|
||||
}{
|
||||
{"Regular user", testing_tools.TestUserId, false},
|
||||
{"Admin user", testing_tools.TestAdminId, true},
|
||||
{"Owner user", testing_tools.TestOwnerId, true},
|
||||
{"Regular service user", testing_tools.TestServiceUserId, false},
|
||||
{"Admin service user", testing_tools.TestServiceAdminId, true},
|
||||
{"Blocked user", testing_tools.BlockedUserId, false},
|
||||
{"Other user", testing_tools.OtherUserId, false},
|
||||
{"Invalid token", testing_tools.InvalidToken, false},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody *api.PutApiDnsSettingsJSONRequestBody
|
||||
expectedStatus int
|
||||
verifyResponse func(t *testing.T, settings *api.DNSSettings)
|
||||
expectedDBDisabledMgmtLen int
|
||||
expectedDBDisabledMgmtItem string
|
||||
}{
|
||||
{
|
||||
name: "Update disabled management groups",
|
||||
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
|
||||
DisabledManagementGroups: []string{testing_tools.TestGroupId},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
|
||||
t.Helper()
|
||||
assert.Equal(t, 1, len(settings.DisabledManagementGroups))
|
||||
assert.Equal(t, testing_tools.TestGroupId, settings.DisabledManagementGroups[0])
|
||||
},
|
||||
expectedDBDisabledMgmtLen: 1,
|
||||
expectedDBDisabledMgmtItem: testing_tools.TestGroupId,
|
||||
},
|
||||
{
|
||||
name: "Update with empty disabled management groups",
|
||||
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
|
||||
DisabledManagementGroups: []string{},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, settings *api.DNSSettings) {
|
||||
t.Helper()
|
||||
assert.Equal(t, 0, len(settings.DisabledManagementGroups))
|
||||
},
|
||||
expectedDBDisabledMgmtLen: 0,
|
||||
},
|
||||
{
|
||||
name: "Update with non-existing group",
|
||||
requestBody: &api.PutApiDnsSettingsJSONRequestBody{
|
||||
DisabledManagementGroups: []string{"nonExistingGroupId"},
|
||||
},
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns.sql", nil, false)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/dns/settings", user.userId)
|
||||
recorder := httptest.NewRecorder()
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
|
||||
if !expectResponse {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.verifyResponse != nil {
|
||||
got := &api.DNSSettings{}
|
||||
if err := json.Unmarshal(content, got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
tc.verifyResponse(t, got)
|
||||
|
||||
// Verify DNS settings directly in the DB
|
||||
db := testing_tools.GetDB(t, am.GetStore())
|
||||
dbAccount := testing_tools.VerifyAccountSettings(t, db)
|
||||
assert.Equal(t, tc.expectedDBDisabledMgmtLen, len(dbAccount.DNSSettings.DisabledManagementGroups))
|
||||
if tc.expectedDBDisabledMgmtItem != "" {
|
||||
assert.Contains(t, dbAccount.DNSSettings.DisabledManagementGroups, tc.expectedDBDisabledMgmtItem)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user