mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
14 Commits
fix-reused
...
fix/filter
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a4d377066 | ||
|
|
92d5418c02 | ||
|
|
09da089a90 | ||
|
|
11eb725ac8 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 | ||
|
|
8e7b016be2 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: PR Title Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
check-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const allowedTags = [
|
||||
'management',
|
||||
'client',
|
||||
'signal',
|
||||
'proxy',
|
||||
'relay',
|
||||
'misc',
|
||||
'infrastructure',
|
||||
'self-hosted',
|
||||
'doc',
|
||||
];
|
||||
|
||||
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||
const match = title.match(pattern);
|
||||
|
||||
if (!match) {
|
||||
core.setFailed(
|
||||
`PR title must start with a tag in brackets.\n` +
|
||||
`Example: [client] fix something\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||
|
||||
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||
if (invalid.length > 0) {
|
||||
core.setFailed(
|
||||
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||
@@ -22,6 +22,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// cgnatPrefix is the RFC 6598 Carrier-Grade NAT range (100.64.0.0/10).
|
||||
// Addresses in this range are used by CNI plugins (Cilium, Calico, etc.) for pod networking
|
||||
// and are not suitable for direct peer-to-peer connectivity between hosts.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
@@ -175,6 +180,15 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
// Filter addresses in the RFC 6598 CGNAT range (100.64.0.0/10) that are not part of the
|
||||
// NetBird WireGuard network. These addresses are commonly assigned by Kubernetes CNI plugins
|
||||
// (Cilium, Calico, etc.) for pod networking and are not routable between hosts.
|
||||
if cgnatPrefix.Contains(a) && !u.address.Network.Contains(a) {
|
||||
u.addrCache.Store(addr.String(), true)
|
||||
log.Infof("Address %s is in the CGNAT range (%s), likely a CNI pod address, refusing to write", addr, cgnatPrefix)
|
||||
return fmt.Errorf("address %s is in the CGNAT range (%s), refusing to write", addr, cgnatPrefix)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
|
||||
@@ -165,6 +165,10 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateInCGNAT(candidate, w.config.WgConfig.WgInterface.Address().Network) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -362,6 +366,10 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
if candidateInCGNAT(candidate, w.config.WgConfig.WgInterface.Address().Network) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
w.log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
@@ -496,6 +504,11 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
// cgnatPrefix is the RFC 6598 Carrier-Grade NAT range (100.64.0.0/10).
|
||||
// Addresses in this range are used by CNI plugins (Cilium, Calico, etc.) for pod networking
|
||||
// and are not suitable for direct peer-to-peer connectivity between hosts.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
@@ -524,6 +537,32 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// candidateInCGNAT checks if a candidate address falls within the RFC 6598 CGNAT range (100.64.0.0/10).
|
||||
// These addresses are commonly used by Kubernetes CNI plugins (Cilium, Calico) for pod networking
|
||||
// and are not routable between hosts, making them unsuitable as ICE candidates.
|
||||
// The wgNetwork parameter is the NetBird WireGuard network prefix — if the candidate address is within
|
||||
// this network, it is not filtered here (it's handled separately by the NetBird network check).
|
||||
func candidateInCGNAT(candidate ice.Candidate, wgNetwork netip.Prefix) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cgnatPrefix.Contains(addr) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Don't filter if the address is within the WireGuard network itself —
|
||||
// that's handled by the NetBird network membership check elsewhere.
|
||||
if wgNetwork.IsValid() && wgNetwork.Contains(addr) {
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("Ignoring candidate [%s], its address %s is in the CGNAT range (%s) likely assigned by a CNI plugin",
|
||||
candidate.String(), addr, cgnatPrefix)
|
||||
return true
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -22,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -42,6 +39,8 @@ const (
|
||||
var DefaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
// Kubernetes CNI interfaces
|
||||
"cilium_", "cilium", "lxc", "cali", "flannel", "cni", "weave",
|
||||
}
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
@@ -98,6 +97,7 @@ type Config struct {
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
IFaceBlackListAppliedDefaults []string `json:",omitempty"`
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
@@ -357,10 +357,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if len(config.IFaceBlackList) == 0 {
|
||||
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||
strings.Join(DefaultInterfaceBlacklist, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...)
|
||||
if changed := config.mergeDefaultIFaceBlacklist(); changed {
|
||||
updated = true
|
||||
}
|
||||
|
||||
@@ -594,6 +591,37 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// mergeDefaultIFaceBlacklist ensures that new entries added to DefaultInterfaceBlacklist
|
||||
// are merged into an existing IFaceBlackList on upgrade, while respecting entries that
|
||||
// the user deliberately removed. It tracks which defaults have been offered via
|
||||
// IFaceBlackListAppliedDefaults so removals are not undone.
|
||||
func (config *Config) mergeDefaultIFaceBlacklist() (updated bool) {
|
||||
if len(config.IFaceBlackList) == 0 {
|
||||
log.Infof("filling in interface blacklist with defaults: [ %s ]",
|
||||
strings.Join(DefaultInterfaceBlacklist, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...)
|
||||
config.IFaceBlackListAppliedDefaults = append([]string{}, DefaultInterfaceBlacklist...)
|
||||
return true
|
||||
}
|
||||
|
||||
// Find defaults not yet tracked in AppliedDefaults — these are genuinely new.
|
||||
// Entries already in AppliedDefaults were either kept or deliberately removed by the user.
|
||||
newDefaults := util.SliceDiff(DefaultInterfaceBlacklist, config.IFaceBlackListAppliedDefaults)
|
||||
if len(newDefaults) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only add entries not already present in the blacklist (avoid duplicates)
|
||||
toAdd := util.SliceDiff(newDefaults, config.IFaceBlackList)
|
||||
if len(toAdd) > 0 {
|
||||
log.Infof("merging new default interface blacklist entries: [ %s ]",
|
||||
strings.Join(toAdd, " "))
|
||||
config.IFaceBlackList = append(config.IFaceBlackList, toAdd...)
|
||||
}
|
||||
config.IFaceBlackListAppliedDefaults = append(config.IFaceBlackListAppliedDefaults, newDefaults...)
|
||||
return true
|
||||
}
|
||||
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
@@ -639,290 +667,3 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
func update(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist
|
||||
func GetConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, false)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, true)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// initialize through apply() without changes
|
||||
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, err
|
||||
} else if changed {
|
||||
if err = WriteOutConfig(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
} else if !createIfMissing {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectWriteOutConfig(path string, config *Config) error {
|
||||
return util.DirectWriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
|
||||
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||
}
|
||||
|
||||
return directUpdate(input)
|
||||
}
|
||||
|
||||
func directUpdate(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||
// This is useful for exporting config to alternative storage mechanisms
|
||||
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func ConfigToJSON(config *Config) (string, error) {
|
||||
bs, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||
// This is useful for restoring config from alternative storage mechanisms.
|
||||
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||
config := &Config{}
|
||||
err := json.Unmarshal([]byte(jsonStr), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply defaults to ensure required fields are initialized.
|
||||
// This mirrors what readConfig does after loading from file.
|
||||
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
301
client/internal/profilemanager/config_io.go
Normal file
301
client/internal/profilemanager/config_io.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
return update(input)
|
||||
}
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
return update(input)
|
||||
}
|
||||
|
||||
func update(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist
|
||||
func GetConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, false)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
return createNewConfig(input)
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func ReadConfig(configPath string) (*Config, error) {
|
||||
return readConfig(configPath, true)
|
||||
}
|
||||
|
||||
// readConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{}
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// initialize through apply() without changes
|
||||
if changed, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, err
|
||||
} else if changed {
|
||||
if err = WriteOutConfig(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
} else if !createIfMissing {
|
||||
return nil, fmt.Errorf("config file %s does not exist", configPath)
|
||||
}
|
||||
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = WriteOutConfig(configPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectWriteOutConfig writes config directly without atomic temp file operations.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectWriteOutConfig(path string, config *Config) error {
|
||||
return util.DirectWriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
|
||||
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
|
||||
if err := util.EnforcePermission(input.ConfigPath); err != nil {
|
||||
log.Errorf("failed to enforce permission on config file: %v", err)
|
||||
}
|
||||
|
||||
return directUpdate(input)
|
||||
}
|
||||
|
||||
func directUpdate(input ConfigInput) (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := config.apply(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ConfigToJSON serializes a Config struct to a JSON string.
|
||||
// This is useful for exporting config to alternative storage mechanisms
|
||||
// (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||
func ConfigToJSON(config *Config) (string, error) {
|
||||
bs, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// ConfigFromJSON deserializes a JSON string to a Config struct.
|
||||
// This is useful for restoring config from alternative storage mechanisms.
|
||||
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
|
||||
func ConfigFromJSON(jsonStr string) (*Config, error) {
|
||||
config := &Config{}
|
||||
err := json.Unmarshal([]byte(jsonStr), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply defaults to ensure required fields are initialized.
|
||||
// This mirrors what readConfig does after loading from file.
|
||||
if _, err := config.apply(ConfigInput{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -108,6 +108,87 @@ func TestExtraIFaceBlackList(t *testing.T) {
|
||||
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
|
||||
}
|
||||
|
||||
func TestIFaceBlackListMigratesNewDefaults(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a config that simulates an old install with a partial IFaceBlackList
|
||||
// (missing the newer CNI entries like "cilium_", "cali", etc.)
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate an old config that predates AppliedDefaults tracking:
|
||||
// it has only the original entries, no CNI prefixes, and no AppliedDefaults.
|
||||
oldList := []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||
config.IFaceBlackList = oldList
|
||||
config.IFaceBlackListAppliedDefaults = nil
|
||||
err = WriteOutConfig(configPath, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-read the config — apply() should merge in missing defaults
|
||||
reloaded, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, entry := range DefaultInterfaceBlacklist {
|
||||
assert.Contains(t, reloaded.IFaceBlackList, entry,
|
||||
"IFaceBlackList should contain default entry %q after migration", entry)
|
||||
}
|
||||
|
||||
// Verify no duplicates were introduced
|
||||
seen := make(map[string]bool)
|
||||
for _, entry := range reloaded.IFaceBlackList {
|
||||
assert.False(t, seen[entry], "duplicate entry %q in IFaceBlackList", entry)
|
||||
seen[entry] = true
|
||||
}
|
||||
|
||||
// AppliedDefaults should now track all current defaults
|
||||
for _, entry := range DefaultInterfaceBlacklist {
|
||||
assert.Contains(t, reloaded.IFaceBlackListAppliedDefaults, entry,
|
||||
"AppliedDefaults should track %q", entry)
|
||||
}
|
||||
|
||||
// Re-read again — should not change (idempotent)
|
||||
reloaded2, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, reloaded.IFaceBlackList, reloaded2.IFaceBlackList,
|
||||
"IFaceBlackList should be stable on subsequent reads")
|
||||
}
|
||||
|
||||
func TestIFaceBlackListRespectsUserRemoval(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a fresh config (all defaults applied)
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, config.IFaceBlackList, "cali")
|
||||
|
||||
// User deliberately removes "cali" from their blacklist
|
||||
filtered := make([]string, 0, len(config.IFaceBlackList))
|
||||
for _, entry := range config.IFaceBlackList {
|
||||
if entry != "cali" {
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
}
|
||||
config.IFaceBlackList = filtered
|
||||
err = WriteOutConfig(configPath, config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-read — "cali" should NOT be re-added because it's in AppliedDefaults
|
||||
reloaded, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, reloaded.IFaceBlackList, "cali",
|
||||
"user-removed entry should not be re-added")
|
||||
|
||||
// AppliedDefaults should still contain "cali" (it was offered)
|
||||
assert.Contains(t, reloaded.IFaceBlackListAppliedDefaults, "cali")
|
||||
}
|
||||
|
||||
func TestHiddenPreSharedKey(t *testing.T) {
|
||||
hidden := "**********"
|
||||
samplePreSharedKey := "mysecretpresharedkey"
|
||||
|
||||
@@ -849,14 +849,26 @@ func (s *Server) cleanupConnection() error {
|
||||
if s.actCancel == nil {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
// to skip the engine shutdown entirely.
|
||||
var engine *internal.Engine
|
||||
if s.connectClient != nil {
|
||||
engine = s.connectClient.Engine()
|
||||
}
|
||||
|
||||
s.actCancel()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.connectClient.Stop(); err != nil {
|
||||
return err
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.connectClient = nil
|
||||
|
||||
@@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||
mgmt := cfg.Management
|
||||
|
||||
dnsDomain := mgmt.DnsDomain
|
||||
singleAccModeDomain := dnsDomain
|
||||
|
||||
// Extract port from listen address
|
||||
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||
if err != nil {
|
||||
@@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
mgmtSrv := mgmtServer.NewServer(
|
||||
&mgmtServer.Config{
|
||||
NbConfig: mgmtConfig,
|
||||
DNSDomain: dnsDomain,
|
||||
MgmtSingleAccModeDomain: singleAccModeDomain,
|
||||
DNSDomain: "",
|
||||
MgmtSingleAccModeDomain: "",
|
||||
AutoResolveDomains: true,
|
||||
MgmtPort: mgmtPort,
|
||||
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||
|
||||
@@ -73,7 +73,10 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
@@ -132,7 +135,10 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
|
||||
@@ -2,7 +2,7 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -13,108 +13,20 @@ const (
|
||||
exposeTTL = 90 * time.Second
|
||||
exposeReapInterval = 30 * time.Second
|
||||
maxExposesPerPeer = 10
|
||||
exposeReapBatch = 100
|
||||
)
|
||||
|
||||
type trackedExpose struct {
|
||||
mu sync.Mutex
|
||||
domain string
|
||||
accountID string
|
||||
peerID string
|
||||
lastRenewed time.Time
|
||||
expiring bool
|
||||
type exposeReaper struct {
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
type exposeTracker struct {
|
||||
activeExposes sync.Map
|
||||
exposeCreateMu sync.Mutex
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func exposeKey(peerID, domain string) string {
|
||||
return peerID + ":" + domain
|
||||
}
|
||||
|
||||
// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new
|
||||
// active expose session under the same lock. Returns (true, false) if the expose
|
||||
// was already tracked (duplicate), (false, true) if tracking succeeded, and
|
||||
// (false, false) if the peer has reached the limit.
|
||||
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
|
||||
t.exposeCreateMu.Lock()
|
||||
defer t.exposeCreateMu.Unlock()
|
||||
|
||||
key := exposeKey(peerID, domain)
|
||||
_, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{
|
||||
domain: domain,
|
||||
accountID: accountID,
|
||||
peerID: peerID,
|
||||
lastRenewed: time.Now(),
|
||||
})
|
||||
if loaded {
|
||||
return true, false
|
||||
}
|
||||
|
||||
if t.CountPeerExposes(peerID) > maxExposesPerPeer {
|
||||
t.activeExposes.Delete(key)
|
||||
return false, false
|
||||
}
|
||||
|
||||
return false, true
|
||||
}
|
||||
|
||||
// UntrackExpose removes an active expose session from tracking.
|
||||
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
|
||||
t.activeExposes.Delete(exposeKey(peerID, domain))
|
||||
}
|
||||
|
||||
// CountPeerExposes returns the number of active expose sessions for a peer.
|
||||
func (t *exposeTracker) CountPeerExposes(peerID string) int {
|
||||
count := 0
|
||||
t.activeExposes.Range(func(_, val any) bool {
|
||||
if expose := val.(*trackedExpose); expose.peerID == peerID {
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer.
|
||||
func (t *exposeTracker) MaxExposesPerPeer() int {
|
||||
return maxExposesPerPeer
|
||||
}
|
||||
|
||||
// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose.
|
||||
// Returns false if the expose is not tracked or is being reaped.
|
||||
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
|
||||
key := exposeKey(peerID, domain)
|
||||
val, ok := t.activeExposes.Load(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
if expose.expiring {
|
||||
expose.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
expose.lastRenewed = time.Now()
|
||||
expose.mu.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// StopTrackedExpose removes an active expose session from tracking.
|
||||
// Returns false if the expose was not tracked.
|
||||
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
|
||||
key := exposeKey(peerID, domain)
|
||||
_, ok := t.activeExposes.LoadAndDelete(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
|
||||
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
|
||||
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
|
||||
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
|
||||
go func() {
|
||||
// start with a random delay
|
||||
rn := rand.IntN(10)
|
||||
time.Sleep(time.Duration(rn) * time.Second)
|
||||
|
||||
ticker := time.NewTicker(exposeReapInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -123,41 +35,31 @@ func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
t.reapExpiredExposes()
|
||||
r.reapExpiredExposes(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *exposeTracker) reapExpiredExposes() {
|
||||
t.activeExposes.Range(func(key, val any) bool {
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expired := time.Since(expose.lastRenewed) > exposeTTL
|
||||
if expired {
|
||||
expose.expiring = true
|
||||
}
|
||||
expose.mu.Unlock()
|
||||
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
|
||||
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get expired ephemeral services: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !expired {
|
||||
return true
|
||||
for _, svc := range expired {
|
||||
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
|
||||
|
||||
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
|
||||
|
||||
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
|
||||
|
||||
s, _ := status.FromError(err)
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
t.activeExposes.Delete(key)
|
||||
case s.ErrorType == status.NotFound:
|
||||
log.Debugf("service %s was already deleted", expose.domain)
|
||||
default:
|
||||
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
log.Debugf("service %s was already deleted by another instance", svc.Domain)
|
||||
} else {
|
||||
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,184 +10,62 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestExposeKey(t *testing.T) {
|
||||
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
|
||||
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
|
||||
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
|
||||
}
|
||||
|
||||
func TestTrackExposeIfAllowed(t *testing.T) {
|
||||
t.Run("first track succeeds", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.False(t, alreadyTracked, "first track should not be duplicate")
|
||||
assert.True(t, ok, "first track should be allowed")
|
||||
})
|
||||
|
||||
t.Run("duplicate track detected", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.True(t, alreadyTracked, "second track should be duplicate")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("rejects when at limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
assert.True(t, ok, "track %d should be allowed", i)
|
||||
}
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
|
||||
assert.False(t, alreadyTracked)
|
||||
assert.False(t, ok, "should reject when at limit")
|
||||
})
|
||||
|
||||
t.Run("other peer unaffected by limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
}
|
||||
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
assert.True(t, ok, "other peer should still be within limit")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUntrackExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.UntrackExpose("peer1", "a.com")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
}
|
||||
|
||||
func TestCountPeerExposes(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
|
||||
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeer(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should not find untracked expose")
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
found = tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.True(t, found, "should find tracked expose")
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
// Simulate reaper marking the expose as expiring
|
||||
key := exposeKey("peer1", "a.com")
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.expiring = true
|
||||
expose.mu.Unlock()
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should reject renewal when expiring")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the tracked entry
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
// Manually expire the service by backdating meta_last_renewed_at
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Add an active (non-expired) tracking entry
|
||||
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
|
||||
domain: "active.com",
|
||||
accountID: testAccountID,
|
||||
peerID: "peer1",
|
||||
lastRenewed: time.Now(),
|
||||
// Create a non-expired service
|
||||
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8081,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, exists := tracker.activeExposes.Load(key)
|
||||
assert.False(t, exists, "expired expose should be removed")
|
||||
// Expired service should be deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "expired service should be deleted")
|
||||
|
||||
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
|
||||
assert.True(t, exists, "active expose should remain")
|
||||
// Non-expired service should remain
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
|
||||
require.NoError(t, err, "active service should remain")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
func TestReapAlreadyDeletedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Expire it
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
// Delete the service before reaping
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Renew should succeed before reaping
|
||||
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
|
||||
|
||||
// Re-expire and reap
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
|
||||
// Entry is deleted, renew returns false
|
||||
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
|
||||
// Reaping should handle the already-deleted service gracefully
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}
|
||||
|
||||
func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
func TestConcurrentReapAndRenew(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
@@ -198,59 +76,133 @@ func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Manually expire all tracked entries
|
||||
tracker.activeExposes.Range(func(_, val any) bool {
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.reapExpiredExposes()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.CountPeerExposes(testPeerID)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
|
||||
}
|
||||
|
||||
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
|
||||
expose := &trackedExpose{
|
||||
lastRenewed: time.Now().Add(-1 * time.Hour),
|
||||
// Expire all services
|
||||
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
|
||||
require.NoError(t, err)
|
||||
for _, svc := range services {
|
||||
if svc.Source == rpservice.SourceEphemeral {
|
||||
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now()
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
_ = time.Since(expose.lastRenewed)
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expose.mu.Lock()
|
||||
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
|
||||
expose.mu.Unlock()
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count, "all expired services should be reaped")
|
||||
}
|
||||
|
||||
func TestRenewEphemeralService(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("renew succeeds for active service", func(t *testing.T) {
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8082,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no active expose session")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountAndExistsEphemeralServices(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8083,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "service should exist")
|
||||
|
||||
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "non-existent service should not exist")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeerEnforced(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range maxExposesPerPeer {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8090 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err, "expose %d should succeed", i)
|
||||
}
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9999,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
|
||||
}
|
||||
|
||||
func TestReapSkipsRenewedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8086,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expire the service
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Renew it before the reaper runs
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err, "renewed service should survive reaping")
|
||||
}
|
||||
|
||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||
t.Helper()
|
||||
svc, err := s.GetServiceByDomain(context.Background(), domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
expired := time.Now().Add(-2 * exposeTTL)
|
||||
svc.Meta.LastRenewedAt = &expired
|
||||
err = s.UpdateService(context.Background(), svc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ type Manager struct {
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeTracker *exposeTracker
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
@@ -49,13 +49,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeTracker = &exposeTracker{manager: mgr}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
return mgr
|
||||
}
|
||||
|
||||
// StartExposeReaper delegates to the expose tracker.
|
||||
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
|
||||
func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeTracker.StartExposeReaper(ctx)
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
@@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -215,8 +215,54 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, servi
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
// persistNewEphemeralService creates an ephemeral service inside a single transaction
|
||||
// that also enforces the duplicate and per-peer limit checks atomically.
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Lock the peer row to serialize concurrent creates for the same peer.
|
||||
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
|
||||
// table returns no rows and acquires no locks, allowing concurrent inserts
|
||||
// to bypass the per-peer limit.
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
|
||||
return fmt.Errorf("lock peer row: %w", err)
|
||||
}
|
||||
|
||||
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing expose: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
|
||||
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count peer exposes: %w", err)
|
||||
}
|
||||
if count >= int64(maxExposesPerPeer) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
@@ -225,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
return status.Errorf(status.AlreadyExists, "domain already taken")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -306,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -412,10 +458,6 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
return err
|
||||
}
|
||||
|
||||
if s.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
@@ -457,9 +499,6 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
|
||||
}
|
||||
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
@@ -681,26 +720,13 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
svc.Meta.LastRenewedAt = &now
|
||||
svc.SourcePeer = peerID
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
svc.Meta.LastRenewedAt = &now
|
||||
|
||||
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
|
||||
if alreadyTracked {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
if !allowed {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
@@ -748,26 +774,17 @@ func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
|
||||
// Returns an error if the expose is not actively tracked.
|
||||
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
|
||||
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
|
||||
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
|
||||
}
|
||||
return nil
|
||||
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
|
||||
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
|
||||
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.exposeTracker.StopTrackedExpose(peerID, domain) {
|
||||
log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -788,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
svc, err := m.store.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -848,6 +865,57 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
|
||||
// that it is still expired under a row lock. This prevents deleting a service
|
||||
// that was renewed between the batch query and this delete, and ensures only one
|
||||
// management instance processes the deletion
|
||||
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
var svc *service.Service
|
||||
deleted := false
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
|
||||
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
|
||||
}
|
||||
|
||||
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
deleted = true
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
return nil
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
|
||||
peer = nil
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
|
||||
if peer == nil {
|
||||
return meta
|
||||
|
||||
@@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
GetServiceByDomain(ctx, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
@@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
@@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
GetServiceByDomain(ctx, "exists.com").
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
GetServiceByDomain(ctx, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
@@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
@@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
@@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
GetServiceByDomain(ctx, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
@@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
GetServiceByDomain(ctx, "test.com").
|
||||
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
@@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
GetServiceByDomain(ctx, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
@@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
GetServiceByDomain(ctx, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
@@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
GetServiceByDomain(ctx, "existing.com").
|
||||
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
@@ -425,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(srv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -705,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -720,7 +720,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
domains: []string{"test.netbird.io"},
|
||||
},
|
||||
}
|
||||
mgr.exposeTracker = &exposeTracker{manager: mgr}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
return mgr, testStore
|
||||
}
|
||||
@@ -814,7 +814,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
|
||||
|
||||
// Verify service is persisted in store
|
||||
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Domain, persisted.Domain)
|
||||
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
||||
@@ -977,7 +977,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify service is deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "service should be deleted")
|
||||
})
|
||||
|
||||
@@ -1012,41 +1012,43 @@ func TestStopServiceFromPeer(t *testing.T) {
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "service should be deleted")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
|
||||
func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create")
|
||||
|
||||
// Look up the service by domain to get its store ID
|
||||
svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
|
||||
|
||||
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete via the API path (user-initiated)
|
||||
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
|
||||
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
|
||||
|
||||
// A new expose should succeed (not blocked by stale tracking)
|
||||
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9090,
|
||||
Protocol: "http",
|
||||
})
|
||||
assert.NoError(t, err, "new expose should succeed after API delete cleared tracking")
|
||||
assert.NoError(t, err, "new expose should succeed after API delete")
|
||||
}
|
||||
|
||||
func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
|
||||
func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
@@ -1058,12 +1060,16 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked")
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count, "all ephemeral services should exist")
|
||||
|
||||
err := mgr.DeleteAllServices(ctx, testAccountID, testUserID)
|
||||
err = mgr.DeleteAllServices(ctx, testAccountID, testUserID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices")
|
||||
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices")
|
||||
}
|
||||
|
||||
func TestRenewServiceFromPeer(t *testing.T) {
|
||||
@@ -1130,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -6,13 +6,16 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
@@ -49,17 +52,25 @@ const (
|
||||
SourceEphemeral = "ephemeral"
|
||||
)
|
||||
|
||||
type TargetOptions struct {
|
||||
SkipTLSVerify bool `json:"skip_tls_verify"`
|
||||
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
Options TargetOptions `gorm:"embedded" json:"options"`
|
||||
}
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
@@ -123,7 +134,7 @@ type Service struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
Domain string `gorm:"type:varchar(255);uniqueIndex"`
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
@@ -133,8 +144,8 @@ type Service struct {
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent'"`
|
||||
SourcePeer string
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
}
|
||||
|
||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||
@@ -194,7 +205,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
apiTargets = append(apiTargets, api.ServiceTarget{
|
||||
st := api.ServiceTarget{
|
||||
Path: target.Path,
|
||||
Host: &target.Host,
|
||||
Port: target.Port,
|
||||
@@ -202,7 +213,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
TargetId: target.TargetId,
|
||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||
Enabled: target.Enabled,
|
||||
})
|
||||
}
|
||||
st.Options = targetOptionsToAPI(target.Options)
|
||||
apiTargets = append(apiTargets, st)
|
||||
}
|
||||
|
||||
meta := api.ServiceMeta{
|
||||
@@ -256,10 +269,14 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
||||
|
||||
pm := &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
pm.Options = targetOptionsToProto(target.Options)
|
||||
pathMappings = append(pathMappings, pm)
|
||||
}
|
||||
|
||||
auth := &proto.Authentication{
|
||||
@@ -312,13 +329,87 @@ func isDefaultPort(scheme string, port int) bool {
|
||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||
}
|
||||
|
||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||
type PathRewriteMode string
|
||||
|
||||
const (
|
||||
PathRewritePreserve PathRewriteMode = "preserve"
|
||||
)
|
||||
|
||||
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
switch mode {
|
||||
case PathRewritePreserve:
|
||||
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
|
||||
default:
|
||||
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
|
||||
}
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
if opts.SkipTLSVerify {
|
||||
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
s := opts.RequestTimeout.String()
|
||||
apiOpts.RequestTimeout = &s
|
||||
}
|
||||
if opts.PathRewrite != "" {
|
||||
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
|
||||
apiOpts.PathRewrite = &pr
|
||||
}
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
}
|
||||
return popts
|
||||
}
|
||||
|
||||
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
|
||||
var opts TargetOptions
|
||||
if o.SkipTlsVerify != nil {
|
||||
opts.SkipTLSVerify = *o.SkipTlsVerify
|
||||
}
|
||||
if o.RequestTimeout != nil {
|
||||
d, err := time.ParseDuration(*o.RequestTimeout)
|
||||
if err != nil {
|
||||
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
|
||||
}
|
||||
opts.RequestTimeout = d
|
||||
}
|
||||
if o.PathRewrite != nil {
|
||||
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
|
||||
}
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
|
||||
s.Name = req.Name
|
||||
s.Domain = req.Domain
|
||||
s.AccountID = accountID
|
||||
|
||||
targets := make([]*Target, 0, len(req.Targets))
|
||||
for _, apiTarget := range req.Targets {
|
||||
for i, apiTarget := range req.Targets {
|
||||
target := &Target{
|
||||
AccountID: accountID,
|
||||
Path: apiTarget.Path,
|
||||
@@ -331,6 +422,13 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||
if apiTarget.Host != nil {
|
||||
target.Host = *apiTarget.Host
|
||||
}
|
||||
if apiTarget.Options != nil {
|
||||
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target.Options = opts
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
s.Targets = targets
|
||||
@@ -368,6 +466,8 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||
}
|
||||
s.Auth.BearerAuth = bearerAuth
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
@@ -400,11 +500,113 @@ func (s *Service) Validate() error {
|
||||
if target.TargetId == "" {
|
||||
return fmt.Errorf("target %d has empty target_id", i)
|
||||
}
|
||||
if err := validateTargetOptions(i, &target.Options); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
|
||||
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
|
||||
|
||||
// hopByHopHeaders are headers that must not be set as custom headers
|
||||
// because they are connection-level and stripped by the proxy.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Proxy-Connection": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
}
|
||||
|
||||
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
|
||||
// and cannot be overridden.
|
||||
var reservedHeaders = map[string]struct{}{
|
||||
"Content-Length": {},
|
||||
"Content-Type": {},
|
||||
"Cookie": {},
|
||||
"Forwarded": {},
|
||||
"X-Forwarded-For": {},
|
||||
"X-Forwarded-Host": {},
|
||||
"X-Forwarded-Port": {},
|
||||
"X-Forwarded-Proto": {},
|
||||
"X-Real-Ip": {},
|
||||
}
|
||||
|
||||
func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
|
||||
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
|
||||
}
|
||||
|
||||
if opts.RequestTimeout != 0 {
|
||||
if opts.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
if opts.RequestTimeout > maxRequestTimeout {
|
||||
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCustomHeaders(idx int, headers map[string]string) error {
|
||||
if len(headers) > maxCustomHeaders {
|
||||
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
|
||||
}
|
||||
seen := make(map[string]string, len(headers))
|
||||
for key, value := range headers {
|
||||
if !httpHeaderNameRe.MatchString(key) {
|
||||
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
|
||||
}
|
||||
if len(key) > maxHeaderKeyLen {
|
||||
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
|
||||
}
|
||||
if len(value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
|
||||
}
|
||||
if containsCRLF(key) || containsCRLF(value) {
|
||||
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
|
||||
}
|
||||
canonical := http.CanonicalHeaderKey(key)
|
||||
if prev, ok := seen[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
|
||||
}
|
||||
seen[canonical] = key
|
||||
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
|
||||
}
|
||||
if _, ok := reservedHeaders[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
|
||||
}
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsCRLF(s string) bool {
|
||||
return strings.ContainsAny(s, "\r\n")
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
||||
}
|
||||
@@ -417,6 +619,12 @@ func (s *Service) Copy() *Service {
|
||||
targets := make([]*Target, len(s.Targets))
|
||||
for i, target := range s.Targets {
|
||||
targetCopy := *target
|
||||
if len(target.Options.CustomHeaders) > 0 {
|
||||
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
|
||||
for k, v := range target.Options.CustomHeaders {
|
||||
targetCopy.Options.CustomHeaders[k] = v
|
||||
}
|
||||
}
|
||||
targets[i] = &targetCopy
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -87,6 +88,188 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode PathRewriteMode
|
||||
wantErr string
|
||||
}{
|
||||
{"empty is default", "", ""},
|
||||
{"preserve is valid", PathRewritePreserve, ""},
|
||||
{"unknown rejected", "regex", "unknown path_rewrite mode"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.PathRewrite = tt.mode
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
wantErr string
|
||||
}{
|
||||
{"valid 30s", 30 * time.Second, ""},
|
||||
{"valid 2m", 2 * time.Minute, ""},
|
||||
{"zero is fine", 0, ""},
|
||||
{"negative", -1 * time.Second, "must be positive"},
|
||||
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.RequestTimeout = tt.timeout
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
|
||||
t.Run("valid headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"X-Custom": "value",
|
||||
"X-Trace": "abc123",
|
||||
}
|
||||
assert.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("CRLF in key", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
|
||||
})
|
||||
|
||||
t.Run("CRLF in value", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
|
||||
assert.ErrorContains(t, rp.Validate(), "invalid characters")
|
||||
})
|
||||
|
||||
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reserved header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Host header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
|
||||
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
|
||||
})
|
||||
|
||||
t.Run("too many headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
headers := make(map[string]string, 17)
|
||||
for i := range 17 {
|
||||
headers[fmt.Sprintf("X-H%d", i)] = "v"
|
||||
}
|
||||
rp.Targets[0].Options.CustomHeaders = headers
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
|
||||
})
|
||||
|
||||
t.Run("key too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
|
||||
assert.ErrorContains(t, rp.Validate(), "key")
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("value too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
|
||||
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"x-custom": "a",
|
||||
"X-Custom": "b",
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "collide")
|
||||
})
|
||||
}
|
||||
|
||||
func TestToProtoMapping_TargetOptions(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
Options: TargetOptions{
|
||||
SkipTLSVerify: true,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
PathRewrite: PathRewritePreserve,
|
||||
CustomHeaders: map[string]string{"X-Custom": "val"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
|
||||
opts := pm.Path[0].Options
|
||||
require.NotNil(t, opts, "options should be populated")
|
||||
assert.True(t, opts.SkipTlsVerify)
|
||||
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
|
||||
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
|
||||
require.NotNil(t, opts.RequestTimeout)
|
||||
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
|
||||
}
|
||||
|
||||
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
|
||||
}
|
||||
|
||||
func TestIsDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
|
||||
@@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
@@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||
return Create(s, func() *nbgrpc.PKCEVerifierStore {
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create PKCE verifier store: %v", err)
|
||||
}
|
||||
return pkceStore
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
|
||||
@@ -28,9 +28,13 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
const ManagementLegacyPort = 33073
|
||||
const (
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
ManagementLegacyPort = 33073
|
||||
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
Start(ctx context.Context) error
|
||||
@@ -58,6 +62,7 @@ type BaseServer struct {
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
disableLegacyManagementPort bool
|
||||
autoResolveDomains bool
|
||||
|
||||
proxyAuthClose func()
|
||||
|
||||
@@ -81,6 +86,7 @@ type Config struct {
|
||||
DisableMetrics bool
|
||||
DisableGeoliteUpdate bool
|
||||
UserDeleteFromIDPEnabled bool
|
||||
AutoResolveDomains bool
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
|
||||
mgmtPort: cfg.MgmtPort,
|
||||
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||
autoResolveDomains: cfg.AutoResolveDomains,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
s.cancel = cancel
|
||||
s.errCh = make(chan error, 4)
|
||||
|
||||
if s.autoResolveDomains {
|
||||
s.resolveDomains(srvCtx)
|
||||
}
|
||||
|
||||
s.PeersManager()
|
||||
s.GeoLocationManager()
|
||||
|
||||
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
|
||||
_ = s.certManager.Listener().Close()
|
||||
}
|
||||
s.GRPCServer().Stop()
|
||||
s.ReverseProxyGRPCServer().Close()
|
||||
if s.proxyAuthClose != nil {
|
||||
s.proxyAuthClose()
|
||||
s.proxyAuthClose = nil
|
||||
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
|
||||
}()
|
||||
}
|
||||
|
||||
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||
// Fresh installs use the default self-hosted domain, while existing installs reuse the
|
||||
// persisted account domain to keep addressing stable across config changes.
|
||||
func (s *BaseServer) resolveDomains(ctx context.Context) {
|
||||
st := s.Store()
|
||||
|
||||
setDefault := func(logMsg string, args ...any) {
|
||||
if logMsg != "" {
|
||||
log.WithContext(ctx).Warnf(logMsg, args...)
|
||||
}
|
||||
s.dnsDomain = DefaultSelfHostedDomain
|
||||
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||
}
|
||||
|
||||
accountsCount, err := st.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
if accountsCount == 0 {
|
||||
s.dnsDomain = DefaultSelfHostedDomain
|
||||
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, err := st.GetAnyAccountID(ctx)
|
||||
if err != nil {
|
||||
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
if domain == "" {
|
||||
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
|
||||
return
|
||||
}
|
||||
|
||||
s.dnsDomain = domain
|
||||
s.mgmtSingleAccModeDomain = domain
|
||||
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
|
||||
}
|
||||
|
||||
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
||||
installationID := store.GetInstallationID()
|
||||
if installationID != "" {
|
||||
|
||||
63
management/internals/server/server_resolve_domains_test.go
Normal file
63
management/internals/server/server_resolve_domains_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
|
||||
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
}
|
||||
|
||||
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
|
||||
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
|
||||
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
|
||||
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
|
||||
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
|
||||
}
|
||||
|
||||
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
|
||||
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
}
|
||||
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type PKCEVerifierStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
|
||||
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
return &PKCEVerifierStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Store saves a PKCE verifier associated with an OAuth state parameter.
|
||||
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
|
||||
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
|
||||
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
|
||||
return fmt.Errorf("failed to store PKCE verifier: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
|
||||
// Returns the verifier and true if found, or empty string and false if not found.
|
||||
// This enforces single-use semantics for PKCE verifiers.
|
||||
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
|
||||
verifier, err := s.cache.Get(s.ctx, state)
|
||||
if err != nil {
|
||||
log.Debugf("PKCE verifier not found for state")
|
||||
return "", false
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(s.ctx, state); err != nil {
|
||||
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
|
||||
}
|
||||
|
||||
return verifier, true
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -83,20 +82,12 @@ type ProxyServiceServer struct {
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
// TODO: use database to store these instead?
|
||||
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
|
||||
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
|
||||
pkceVerifiers sync.Map
|
||||
pkceCleanupCancel context.CancelFunc
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
type pkceEntry struct {
|
||||
verifier string
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
@@ -108,42 +99,21 @@ type proxyConnection struct {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx := context.Background()
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
pkceCleanupCancel: cancel,
|
||||
}
|
||||
go s.cleanupPKCEVerifiers(ctx)
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
|
||||
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
||||
ticker := time.NewTicker(pkceVerifierTTL)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.pkceVerifiers.Range(func(key, value any) bool {
|
||||
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
|
||||
s.pkceVerifiers.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -160,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.pkceCleanupCancel()
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
@@ -177,11 +142,7 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
ctx := stream.Context()
|
||||
|
||||
peerInfo := ""
|
||||
if p, ok := peer.FromContext(ctx); ok {
|
||||
peerInfo = p.Addr.String()
|
||||
}
|
||||
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
@@ -795,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
||||
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
|
||||
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
|
||||
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
|
||||
}
|
||||
|
||||
return &proto.GetOIDCURLResponse{
|
||||
Url: (&oauth2.Config{
|
||||
@@ -832,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
|
||||
// ValidateState validates the state parameter from an OAuth callback.
|
||||
// Returns the original redirect URL if valid, or an error if invalid.
|
||||
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
||||
v, ok := s.pkceVerifiers.LoadAndDelete(state)
|
||||
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||
if !ok {
|
||||
return "", "", errors.New("no verifier for state")
|
||||
}
|
||||
entry, ok := v.(pkceEntry)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid verifier for state")
|
||||
}
|
||||
if time.Since(entry.createdAt) > pkceVerifierTTL {
|
||||
return "", "", errors.New("PKCE verifier expired")
|
||||
}
|
||||
verifier = entry.verifier
|
||||
|
||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||
parts := strings.Split(state, "|")
|
||||
|
||||
@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
|
||||
}
|
||||
|
||||
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||
clientIP := peerIPFromContext(ctx)
|
||||
clientIP := PeerIPFromContext(ctx)
|
||||
|
||||
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||
|
||||
@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
// peerIPFromContext extracts the client IP from the gRPC context.
|
||||
// PeerIPFromContext extracts the client IP from the gRPC context.
|
||||
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||
func peerIPFromContext(ctx context.Context) clientIP {
|
||||
func PeerIPFromContext(ctx context.Context) string {
|
||||
if addr, ok := realip.FromContext(ctx); ok {
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
@@ -5,11 +5,10 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||
}
|
||||
|
||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
redirectURL := "https://app.example.com/callback"
|
||||
@@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
// Old format had only 2 parts: base64(url)|hmac
|
||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err := s.ValidateState("base64url|hmac")
|
||||
_, _, err = s.ValidateState("base64url|hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
// Store with tampered HMAC
|
||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state signature")
|
||||
}
|
||||
|
||||
@@ -330,13 +330,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
|
||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Sync took %s", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
@@ -743,13 +742,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
@@ -799,6 +791,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Login took %s", time.Since(reqStart))
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
|
||||
@@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -1379,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
// We override incoming domain claims to group users under a single account.
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||
@@ -1414,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
|
||||
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
|
||||
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.Domain = domain
|
||||
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/prometheus/client_golang/prometheus/push"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -3132,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -3966,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
|
||||
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("account-1", nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||
Return("real-domain.com", "private", nil)
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "real-domain.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", nil)
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", status.Errorf(status.NotFound, "no accounts"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", status.Errorf(status.Internal, "db down"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "db down")
|
||||
// Defaults should still be set before error path
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("account-1", nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||
Return("", "", status.Errorf(status.Internal, "query failed"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "query failed")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
@@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
pkceStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
|
||||
@@ -98,12 +98,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||
}
|
||||
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PKCE verifier store: %v", err)
|
||||
}
|
||||
noopMeter := noop.NewMeterProvider().Meter("")
|
||||
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
||||
@@ -5040,6 +5040,99 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
|
||||
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Update("meta_last_renewed_at", time.Now())
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "renew ephemeral service")
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
|
||||
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
|
||||
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
|
||||
// skip malformed legacy data.
|
||||
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
|
||||
cutoff := time.Now().Add(-ttl)
|
||||
var services []*rpservice.Service
|
||||
result := s.db.
|
||||
Select("id", "account_id", "source_peer", "domain").
|
||||
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
|
||||
Limit(limit).
|
||||
Find(&services)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
|
||||
}
|
||||
return services, nil
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
|
||||
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
||||
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
|
||||
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
|
||||
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
|
||||
if lockStrength == LockingStrengthNone {
|
||||
var count int64
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
var ids []string
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
||||
Pluck("id", &ids)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
|
||||
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
||||
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
if lockStrength == LockingStrengthNone {
|
||||
var count int64
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
||||
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
var id string
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Limit(1).
|
||||
Pluck("id", &id)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
||||
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
||||
}
|
||||
return id != "", nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
|
||||
tx := s.db
|
||||
|
||||
|
||||
@@ -257,10 +257,15 @@ type Store interface {
|
||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
|
||||
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
|
||||
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
|
||||
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
|
||||
|
||||
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
||||
|
||||
@@ -208,6 +208,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer mocks base method.
|
||||
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
|
||||
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
||||
}
|
||||
|
||||
// CreateAccessLog mocks base method.
|
||||
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -686,6 +701,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||
}
|
||||
|
||||
// EphemeralServiceExists mocks base method.
|
||||
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
|
||||
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1362,6 +1392,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices mocks base method.
|
||||
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
|
||||
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
|
||||
}
|
||||
|
||||
// GetGroupByID mocks base method.
|
||||
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1887,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
@@ -2401,6 +2446,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
|
||||
}
|
||||
|
||||
// RenewEphemeralService mocks base method.
|
||||
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
|
||||
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// RevokeProxyAccessToken mocks base method.
|
||||
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
185
management/server/telemetry/account_aggregator.go
Normal file
185
management/server/telemetry/account_aggregator.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
)
|
||||
|
||||
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
|
||||
// without publishing individual account labels
|
||||
type AccountDurationAggregator struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*accountHistogram
|
||||
meterProvider *sdkmetric.MeterProvider
|
||||
manualReader *sdkmetric.ManualReader
|
||||
|
||||
FlushInterval time.Duration
|
||||
MaxAge time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type accountHistogram struct {
|
||||
histogram metric.Int64Histogram
|
||||
lastUpdate time.Time
|
||||
}
|
||||
|
||||
// NewAccountDurationAggregator creates aggregator using OTel histograms
|
||||
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
|
||||
manualReader := sdkmetric.NewManualReader(
|
||||
sdkmetric.WithTemporalitySelector(func(kind sdkmetric.InstrumentKind) metricdata.Temporality {
|
||||
return metricdata.DeltaTemporality
|
||||
}),
|
||||
)
|
||||
|
||||
meterProvider := sdkmetric.NewMeterProvider(
|
||||
sdkmetric.WithReader(manualReader),
|
||||
)
|
||||
|
||||
return &AccountDurationAggregator{
|
||||
accounts: make(map[string]*accountHistogram),
|
||||
meterProvider: meterProvider,
|
||||
manualReader: manualReader,
|
||||
FlushInterval: flushInterval,
|
||||
MaxAge: maxAge,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// Record adds a duration for an account using OTel histogram
|
||||
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
accHist, exists := a.accounts[accountID]
|
||||
if !exists {
|
||||
meter := a.meterProvider.Meter("account-aggregator")
|
||||
histogram, err := meter.Int64Histogram(
|
||||
"sync_duration_per_account",
|
||||
metric.WithUnit("milliseconds"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accHist = &accountHistogram{
|
||||
histogram: histogram,
|
||||
}
|
||||
a.accounts[accountID] = accHist
|
||||
}
|
||||
|
||||
accHist.histogram.Record(a.ctx, duration.Milliseconds(),
|
||||
metric.WithAttributes(attribute.String("account_id", accountID)))
|
||||
accHist.lastUpdate = time.Now()
|
||||
}
|
||||
|
||||
// FlushAndGetP95s extracts P95 from each account's histogram
|
||||
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
var rm metricdata.ResourceMetrics
|
||||
err := a.manualReader.Collect(a.ctx, &rm)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
p95s := make([]int64, 0, len(a.accounts))
|
||||
|
||||
for _, scopeMetrics := range rm.ScopeMetrics {
|
||||
for _, metric := range scopeMetrics.Metrics {
|
||||
histogramData, ok := metric.Data.(metricdata.Histogram[int64])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, dataPoint := range histogramData.DataPoints {
|
||||
a.processDataPoint(dataPoint, now, &p95s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.cleanupStaleAccounts(now)
|
||||
|
||||
return p95s
|
||||
}
|
||||
|
||||
// processDataPoint extracts P95 from a single histogram data point
|
||||
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
|
||||
accountID := extractAccountID(dataPoint)
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if p95 := calculateP95FromHistogram(dataPoint); p95 > 0 {
|
||||
*p95s = append(*p95s, p95)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleAccounts removes accounts that haven't been updated recently
|
||||
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
|
||||
for accountID := range a.accounts {
|
||||
if a.isStaleAccount(accountID, now) {
|
||||
delete(a.accounts, accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractAccountID retrieves the account_id from histogram data point attributes
|
||||
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
|
||||
for _, attr := range dp.Attributes.ToSlice() {
|
||||
if attr.Key == "account_id" {
|
||||
return attr.Value.AsString()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isStaleAccount checks if an account hasn't been updated recently
|
||||
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
|
||||
accHist, exists := a.accounts[accountID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return now.Sub(accHist.lastUpdate) > a.MaxAge
|
||||
}
|
||||
|
||||
// calculateP95FromHistogram computes P95 from OTel histogram data
|
||||
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
|
||||
if dp.Count == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95))
|
||||
if targetCount == 0 {
|
||||
targetCount = 1
|
||||
}
|
||||
var cumulativeCount uint64
|
||||
|
||||
for i, bucketCount := range dp.BucketCounts {
|
||||
cumulativeCount += bucketCount
|
||||
if cumulativeCount >= targetCount {
|
||||
if i < len(dp.Bounds) {
|
||||
return int64(dp.Bounds[i])
|
||||
}
|
||||
if maxVal, defined := dp.Max.Value(); defined {
|
||||
return maxVal
|
||||
}
|
||||
return dp.Sum / int64(dp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
return dp.Sum / int64(dp.Count)
|
||||
}
|
||||
|
||||
// Shutdown cleans up resources
|
||||
func (a *AccountDurationAggregator) Shutdown() error {
|
||||
return a.meterProvider.Shutdown(a.ctx)
|
||||
}
|
||||
219
management/server/telemetry/account_aggregator_test.go
Normal file
219
management/server/telemetry/account_aggregator_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
|
||||
// Verify that with delta temporality, each flush window only reflects
|
||||
// recordings since the last flush — not all-time data.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Window 1: Record 100 slow requests (500ms each)
|
||||
for range 100 {
|
||||
agg.Record("account-A", 500*time.Millisecond)
|
||||
}
|
||||
|
||||
p95sWindow1 := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95sWindow1, 1, "should have P95 for one account")
|
||||
firstP95 := p95sWindow1[0]
|
||||
assert.GreaterOrEqual(t, firstP95, int64(200),
|
||||
"first window P95 should reflect the 500ms recordings")
|
||||
|
||||
// Window 2: Record 100 FAST requests (10ms each)
|
||||
for range 100 {
|
||||
agg.Record("account-A", 10*time.Millisecond)
|
||||
}
|
||||
|
||||
p95sWindow2 := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95sWindow2, 1, "should have P95 for one account")
|
||||
secondP95 := p95sWindow2[0]
|
||||
|
||||
// With delta temporality the P95 should drop significantly because
|
||||
// the first window's slow recordings are no longer included.
|
||||
assert.Less(t, secondP95, firstP95,
|
||||
"second window P95 should be lower than first — delta temporality "+
|
||||
"ensures each window only reflects recent recordings")
|
||||
}
|
||||
|
||||
func TestEqualWeightPerAccount(t *testing.T) {
|
||||
// Verify that each account contributes exactly one P95 value,
|
||||
// regardless of how many requests it made.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Account A: 10,000 requests at 500ms (noisy customer)
|
||||
for range 10000 {
|
||||
agg.Record("account-A", 500*time.Millisecond)
|
||||
}
|
||||
|
||||
// Accounts B, C, D: 10 requests each at 50ms (normal customers)
|
||||
for _, id := range []string{"account-B", "account-C", "account-D"} {
|
||||
for range 10 {
|
||||
agg.Record(id, 50*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
|
||||
// Should get exactly 4 P95 values — one per account
|
||||
assert.Len(t, p95s, 4, "each account should contribute exactly one P95")
|
||||
}
|
||||
|
||||
func TestStaleAccountEviction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Use a very short MaxAge so we can test staleness
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-A", 100*time.Millisecond)
|
||||
agg.Record("account-B", 200*time.Millisecond)
|
||||
|
||||
// Both accounts should appear
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
assert.Len(t, p95s, 2, "both accounts should have P95 values")
|
||||
|
||||
// Wait for account-A to become stale, then only update account-B
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
agg.Record("account-B", 200*time.Millisecond)
|
||||
|
||||
p95s = agg.FlushAndGetP95s()
|
||||
assert.Len(t, p95s, 1, "both accounts should have P95 values")
|
||||
|
||||
// account-A should have been evicted from the accounts map
|
||||
agg.mu.RLock()
|
||||
_, accountAExists := agg.accounts["account-A"]
|
||||
_, accountBExists := agg.accounts["account-B"]
|
||||
agg.mu.RUnlock()
|
||||
|
||||
assert.False(t, accountAExists, "stale account-A should be evicted from map")
|
||||
assert.True(t, accountBExists, "active account-B should remain in map")
|
||||
}
|
||||
|
||||
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
|
||||
// Verify that with delta temporality, an evicted stale account does not
|
||||
// reappear in subsequent flushes.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-stale", 100*time.Millisecond)
|
||||
|
||||
// Wait for it to become stale
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// First flush: should detect staleness and evict
|
||||
_ = agg.FlushAndGetP95s()
|
||||
|
||||
agg.mu.RLock()
|
||||
_, exists := agg.accounts["account-stale"]
|
||||
agg.mu.RUnlock()
|
||||
assert.False(t, exists, "account should be evicted after first flush")
|
||||
|
||||
// Second flush: with delta temporality, the stale account should NOT reappear
|
||||
p95sSecond := agg.FlushAndGetP95s()
|
||||
assert.Empty(t, p95sSecond,
|
||||
"evicted account should not reappear in subsequent flushes with delta temporality")
|
||||
}
|
||||
|
||||
func TestP95Calculation_SingleSample(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-A", 150*time.Millisecond)
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 1)
|
||||
// With a single sample, P95 should be the bucket bound containing 150ms
|
||||
assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive")
|
||||
}
|
||||
|
||||
func TestP95Calculation_AllSameValue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// All samples are 100ms — P95 should be the bucket bound containing 100ms
|
||||
for range 100 {
|
||||
agg.Record("account-A", 100*time.Millisecond)
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 1)
|
||||
assert.Greater(t, p95s[0], int64(0))
|
||||
}
|
||||
|
||||
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Account A: all fast (10ms)
|
||||
for range 100 {
|
||||
agg.Record("account-fast", 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// Account B: all slow (5000ms)
|
||||
for range 100 {
|
||||
agg.Record("account-slow", 5000*time.Millisecond)
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 2, "should have two P95 values")
|
||||
|
||||
// Find min and max — they should differ significantly
|
||||
minP95 := p95s[0]
|
||||
maxP95 := p95s[1]
|
||||
if minP95 > maxP95 {
|
||||
minP95, maxP95 = maxP95, minP95
|
||||
}
|
||||
|
||||
assert.Less(t, minP95, int64(1000),
|
||||
"fast account P95 should be well under 1000ms")
|
||||
assert.Greater(t, maxP95, int64(1000),
|
||||
"slow account P95 should be well over 1000ms")
|
||||
}
|
||||
@@ -13,18 +13,24 @@ const HighLatencyThreshold = time.Second * 7
|
||||
|
||||
// GRPCMetrics are gRPC server metrics
|
||||
type GRPCMetrics struct {
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
syncRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
loginRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestHighLatencyCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
syncRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
loginRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestHighLatencyCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
syncRequestDurationP95ByAccount metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
loginRequestDurationP95ByAccount metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
|
||||
// Per-account aggregation
|
||||
syncDurationAggregator *AccountDurationAggregator
|
||||
loginDurationAggregator *AccountDurationAggregator
|
||||
}
|
||||
|
||||
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
||||
@@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
||||
@@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
|
||||
// Then we should be able to extract min, manx, mean and the percentiles.
|
||||
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
|
||||
@@ -113,20 +135,32 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GRPCMetrics{
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
}, err
|
||||
syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||
|
||||
grpcMetrics := &GRPCMetrics{
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
syncDurationAggregator: syncDurationAggregator,
|
||||
loginDurationAggregator: loginDurationAggregator,
|
||||
}
|
||||
|
||||
go grpcMetrics.startSyncP95Flusher()
|
||||
go grpcMetrics.startLoginP95Flusher()
|
||||
|
||||
return grpcMetrics, err
|
||||
}
|
||||
|
||||
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
|
||||
@@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
|
||||
// CountLoginRequestDuration counts the duration of the login gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
|
||||
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
|
||||
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
|
||||
|
||||
if duration > HighLatencyThreshold {
|
||||
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
||||
}
|
||||
@@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
|
||||
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
||||
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
|
||||
grpcMetrics.syncDurationAggregator.Record(accountID, duration)
|
||||
}
|
||||
|
||||
// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram
|
||||
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
|
||||
ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-grpcMetrics.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
|
||||
for _, p95 := range p95s {
|
||||
grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram
|
||||
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
|
||||
ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-grpcMetrics.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
|
||||
for _, p95 := range p95s {
|
||||
grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||
|
||||
@@ -28,10 +28,12 @@ func BenchmarkServeHTTP(b *testing.B) {
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*url.URL{
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -67,10 +69,12 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
ID: id,
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*url.URL{
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -100,15 +104,17 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
paths := make(map[string]*url.URL, pathCount)
|
||||
paths := make(map[string]*proxy.PathTarget, pathCount)
|
||||
for i := range pathCount {
|
||||
path := "/" + rand.Text()
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = path
|
||||
}
|
||||
paths[path] = &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
paths[path] = &proxy.PathTarget{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
},
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
|
||||
@@ -80,14 +80,30 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
}
|
||||
|
||||
pt := result.target
|
||||
|
||||
if pt.SkipTLSVerify {
|
||||
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||
}
|
||||
if pt.RequestTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
if pt.PathRewrite == PathRewritePreserve {
|
||||
rewriteMatchedPath = ""
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
|
||||
Transport: p.transport,
|
||||
FlushInterval: -1,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
|
||||
}
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
@@ -97,16 +113,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
||||
// The pathRewrite parameter controls how the request path is transformed.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
switch pathRewrite {
|
||||
case PathRewritePreserve:
|
||||
// Keep the full original request path as-is.
|
||||
default:
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
|
||||
r.SetURL(target)
|
||||
@@ -116,6 +138,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
for k, v := range customHeaders {
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", true)
|
||||
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
|
||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
|
||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
|
||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
// Root path "/" — no stripping expected
|
||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.com:443/")
|
||||
rewrite := p.rewriteFunc(target, "/", false)
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
@@ -546,6 +546,82 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_PreservePath(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("preserve keeps full request path", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/api/users/123", pr.Out.URL.Path,
|
||||
"preserve should keep the full original request path")
|
||||
})
|
||||
|
||||
t.Run("preserve with root matchedPath", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/anything", pr.Out.URL.Path)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("injects custom headers", func(t *testing.T) {
|
||||
headers := map[string]string{
|
||||
"X-Custom-Auth": "token-abc",
|
||||
"X-Env": "production",
|
||||
}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth"))
|
||||
assert.Equal(t, "production", pr.Out.Header.Get("X-Env"))
|
||||
})
|
||||
|
||||
t.Run("nil customHeaders is fine", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||
})
|
||||
|
||||
t.Run("custom headers override existing request headers", func(t *testing.T) {
|
||||
headers := map[string]string{"X-Override": "new-value"}
|
||||
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.Header.Set("X-Override", "old-value")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
|
||||
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path")
|
||||
assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set")
|
||||
}
|
||||
|
||||
func TestRewriteLocationFunc(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
||||
|
||||
@@ -6,21 +6,41 @@ import (
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||
type PathRewriteMode int
|
||||
|
||||
const (
|
||||
// PathRewriteDefault strips the matched prefix and joins with the target path.
|
||||
PathRewriteDefault PathRewriteMode = iota
|
||||
// PathRewritePreserve keeps the full original request path as-is.
|
||||
PathRewritePreserve
|
||||
)
|
||||
|
||||
// PathTarget holds a backend URL and per-target behavioral options.
|
||||
type PathTarget struct {
|
||||
URL *url.URL
|
||||
SkipTLSVerify bool
|
||||
RequestTimeout time.Duration
|
||||
PathRewrite PathRewriteMode
|
||||
CustomHeaders map[string]string
|
||||
}
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
Paths map[string]*PathTarget
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
}
|
||||
|
||||
type targetResult struct {
|
||||
url *url.URL
|
||||
target *PathTarget
|
||||
matchedPath string
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
@@ -55,10 +75,14 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
target := m.Paths[path]
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
|
||||
pt := m.Paths[path]
|
||||
if pt == nil || pt.URL == nil {
|
||||
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
|
||||
continue
|
||||
}
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
|
||||
return targetResult{
|
||||
url: target,
|
||||
target: pt,
|
||||
matchedPath: path,
|
||||
serviceID: m.ID,
|
||||
accountID: m.AccountID,
|
||||
|
||||
32
proxy/internal/roundtrip/context_test.go
Normal file
32
proxy/internal/roundtrip/context_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestAccountIDContext(t *testing.T) {
|
||||
t.Run("returns empty when missing", func(t *testing.T) {
|
||||
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("round-trips value", func(t *testing.T) {
|
||||
ctx := WithAccountID(context.Background(), "acc-123")
|
||||
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSkipTLSVerifyContext(t *testing.T) {
|
||||
t.Run("false by default", func(t *testing.T) {
|
||||
assert.False(t, skipTLSVerifyFromContext(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("true when set", func(t *testing.T) {
|
||||
ctx := WithSkipTLSVerify(context.Background())
|
||||
assert.True(t, skipTLSVerifyFromContext(ctx))
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -52,9 +53,12 @@ type domainNotification struct {
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// insecureTransport is a clone of transport with TLS verification disabled,
|
||||
// used when per-target skip_tls_verify is set.
|
||||
insecureTransport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
@@ -130,6 +134,9 @@ type ClientDebugInfo struct {
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
type accountIDContextKey struct{}
|
||||
|
||||
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
|
||||
type skipTLSVerifyContextKey struct{}
|
||||
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple domains can share the same client.
|
||||
@@ -249,27 +256,33 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
transport := &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
}
|
||||
|
||||
insecureTransport := transport.Clone()
|
||||
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
|
||||
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: transport,
|
||||
insecureTransport: insecureTransport,
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -373,6 +386,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
insecureTransport := entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -387,6 +401,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
}
|
||||
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -415,6 +430,9 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
if skipTLSVerifyFromContext(req.Context()) {
|
||||
transport = entry.insecureTransport
|
||||
}
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
@@ -457,6 +475,7 @@ func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
@@ -579,3 +598,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
|
||||
// WithSkipTLSVerify marks the context to use an insecure transport that skips
|
||||
// TLS certificate verification for the backend connection.
|
||||
func WithSkipTLSVerify(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
|
||||
}
|
||||
|
||||
func skipTLSVerifyFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -116,6 +116,9 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create real users manager
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
@@ -131,6 +134,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
pkceStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
|
||||
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
paths := make(map[string]*url.URL)
|
||||
paths := make(map[string]*proxy.PathTarget)
|
||||
for _, pathMapping := range mapping.GetPath() {
|
||||
targetURL, err := url.Parse(pathMapping.GetTarget())
|
||||
if err != nil {
|
||||
@@ -734,7 +734,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
}).WithError(err).Error("failed to parse target URL for path, skipping")
|
||||
continue
|
||||
}
|
||||
paths[pathMapping.GetPath()] = targetURL
|
||||
|
||||
pt := &proxy.PathTarget{URL: targetURL}
|
||||
if opts := pathMapping.GetOptions(); opts != nil {
|
||||
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
|
||||
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
|
||||
pt.CustomHeaders = opts.GetCustomHeaders()
|
||||
if d := opts.GetRequestTimeout(); d != nil {
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
}
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
}
|
||||
return proxy.Mapping{
|
||||
ID: mapping.GetId(),
|
||||
@@ -746,6 +756,15 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||
}
|
||||
}
|
||||
|
||||
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
||||
switch mode {
|
||||
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
|
||||
return proxy.PathRewritePreserve
|
||||
default:
|
||||
return proxy.PathRewriteDefault
|
||||
}
|
||||
}
|
||||
|
||||
// debugEndpointAddr returns the address for the debug endpoint.
|
||||
// If addr is empty, it defaults to localhost:8444 for security.
|
||||
func debugEndpointAddr(addr string) string {
|
||||
|
||||
271
shared/management/client/rest/reverse_proxy_services_test.go
Normal file
271
shared/management/client/rest/reverse_proxy_services_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
//go:build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
var testServiceTarget = api.ServiceTarget{
|
||||
TargetId: "peer-123",
|
||||
TargetType: "peer",
|
||||
Protocol: "https",
|
||||
Port: 8443,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
var testService = api.Service{
|
||||
Id: "svc-1",
|
||||
Name: "test-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Meta: api.ServiceMeta{
|
||||
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
Status: "active",
|
||||
},
|
||||
Targets: []api.ServiceTarget{testServiceTarget},
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Service{testService})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ret, 1)
|
||||
assert.Equal(t, testService.Id, ret[0].Id)
|
||||
assert.Equal(t, testService.Name, ret[0].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testService)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testService.Id, ret.Id)
|
||||
assert.Equal(t, testService.Domain, ret.Domain)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.ServiceRequest
|
||||
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||
assert.Equal(t, "test-service", req.Name)
|
||||
assert.Equal(t, "test.example.com", req.Domain)
|
||||
retBytes, _ := json.Marshal(testService)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||
Name: "test-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Targets: []api.ServiceTarget{testServiceTarget},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testService.Id, ret.Id)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||
Name: "test-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Targets: []api.ServiceTarget{testServiceTarget},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.ServiceRequest
|
||||
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||
|
||||
require.Len(t, req.Targets, 1)
|
||||
target := req.Targets[0]
|
||||
require.NotNil(t, target.Options, "options should be present")
|
||||
opts := target.Options
|
||||
require.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present")
|
||||
assert.True(t, *opts.SkipTlsVerify)
|
||||
require.NotNil(t, opts.RequestTimeout, "request_timeout should be present")
|
||||
assert.Equal(t, "30s", *opts.RequestTimeout)
|
||||
require.NotNil(t, opts.PathRewrite, "path_rewrite should be present")
|
||||
assert.Equal(t, api.ServiceTargetOptionsPathRewrite("preserve"), *opts.PathRewrite)
|
||||
require.NotNil(t, opts.CustomHeaders, "custom_headers should be present")
|
||||
assert.Equal(t, "bar", (*opts.CustomHeaders)["X-Foo"])
|
||||
|
||||
retBytes, _ := json.Marshal(testService)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
pathRewrite := api.ServiceTargetOptionsPathRewrite("preserve")
|
||||
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||
Name: "test-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Targets: []api.ServiceTarget{
|
||||
{
|
||||
TargetId: "peer-123",
|
||||
TargetType: "peer",
|
||||
Protocol: "https",
|
||||
Port: 8443,
|
||||
Enabled: true,
|
||||
Options: &api.ServiceTargetOptions{
|
||||
SkipTlsVerify: ptr(true),
|
||||
RequestTimeout: ptr("30s"),
|
||||
PathRewrite: &pathRewrite,
|
||||
CustomHeaders: &map[string]string{"X-Foo": "bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testService.Id, ret.Id)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.ServiceRequest
|
||||
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||
assert.Equal(t, "updated-service", req.Name)
|
||||
retBytes, _ := json.Marshal(testService)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
|
||||
Name: "updated-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Targets: []api.ServiceTarget{testServiceTarget},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testService.Id, ret.Id)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
|
||||
Name: "updated-service",
|
||||
Domain: "test.example.com",
|
||||
Enabled: true,
|
||||
Auth: api.ServiceAuthConfig{},
|
||||
Targets: []api.ServiceTarget{testServiceTarget},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxyServices_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
@@ -3027,6 +3027,28 @@ components:
|
||||
- targets
|
||||
- auth
|
||||
- enabled
|
||||
ServiceTargetOptions:
|
||||
type: object
|
||||
properties:
|
||||
skip_tls_verify:
|
||||
type: boolean
|
||||
description: Skip TLS certificate verification for this backend
|
||||
request_timeout:
|
||||
type: string
|
||||
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
|
||||
path_rewrite:
|
||||
type: string
|
||||
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
|
||||
enum: [preserve]
|
||||
custom_headers:
|
||||
type: object
|
||||
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
|
||||
propertyNames:
|
||||
type: string
|
||||
pattern: '^[!#$%&''*+.^_`|~0-9A-Za-z-]+$'
|
||||
additionalProperties:
|
||||
type: string
|
||||
pattern: '^[^\r\n]*$'
|
||||
ServiceTarget:
|
||||
type: object
|
||||
properties:
|
||||
@@ -3053,6 +3075,8 @@ components:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: Whether this target is enabled
|
||||
options:
|
||||
$ref: '#/components/schemas/ServiceTargetOptions'
|
||||
required:
|
||||
- target_id
|
||||
- target_type
|
||||
|
||||
@@ -326,6 +326,11 @@ const (
|
||||
ServiceTargetTargetTypeResource ServiceTargetTargetType = "resource"
|
||||
)
|
||||
|
||||
// Defines values for ServiceTargetOptionsPathRewrite.
|
||||
const (
|
||||
ServiceTargetOptionsPathRewritePreserve ServiceTargetOptionsPathRewrite = "preserve"
|
||||
)
|
||||
|
||||
// Defines values for TenantResponseStatus.
|
||||
const (
|
||||
TenantResponseStatusActive TenantResponseStatus = "active"
|
||||
@@ -367,6 +372,27 @@ const (
|
||||
GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS"
|
||||
)
|
||||
|
||||
// Defines values for GetApiEventsProxyParamsSortBy.
|
||||
const (
|
||||
GetApiEventsProxyParamsSortByAuthMethod GetApiEventsProxyParamsSortBy = "auth_method"
|
||||
GetApiEventsProxyParamsSortByDuration GetApiEventsProxyParamsSortBy = "duration"
|
||||
GetApiEventsProxyParamsSortByHost GetApiEventsProxyParamsSortBy = "host"
|
||||
GetApiEventsProxyParamsSortByMethod GetApiEventsProxyParamsSortBy = "method"
|
||||
GetApiEventsProxyParamsSortByPath GetApiEventsProxyParamsSortBy = "path"
|
||||
GetApiEventsProxyParamsSortByReason GetApiEventsProxyParamsSortBy = "reason"
|
||||
GetApiEventsProxyParamsSortBySourceIp GetApiEventsProxyParamsSortBy = "source_ip"
|
||||
GetApiEventsProxyParamsSortByStatusCode GetApiEventsProxyParamsSortBy = "status_code"
|
||||
GetApiEventsProxyParamsSortByTimestamp GetApiEventsProxyParamsSortBy = "timestamp"
|
||||
GetApiEventsProxyParamsSortByUrl GetApiEventsProxyParamsSortBy = "url"
|
||||
GetApiEventsProxyParamsSortByUserId GetApiEventsProxyParamsSortBy = "user_id"
|
||||
)
|
||||
|
||||
// Defines values for GetApiEventsProxyParamsSortOrder.
|
||||
const (
|
||||
GetApiEventsProxyParamsSortOrderAsc GetApiEventsProxyParamsSortOrder = "asc"
|
||||
GetApiEventsProxyParamsSortOrderDesc GetApiEventsProxyParamsSortOrder = "desc"
|
||||
)
|
||||
|
||||
// Defines values for GetApiEventsProxyParamsMethod.
|
||||
const (
|
||||
GetApiEventsProxyParamsMethodDELETE GetApiEventsProxyParamsMethod = "DELETE"
|
||||
@@ -2741,7 +2767,8 @@ type ServiceTarget struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Host Backend ip or domain for this target
|
||||
Host *string `json:"host,omitempty"`
|
||||
Host *string `json:"host,omitempty"`
|
||||
Options *ServiceTargetOptions `json:"options,omitempty"`
|
||||
|
||||
// Path URL path prefix for this target
|
||||
Path *string `json:"path,omitempty"`
|
||||
@@ -2765,6 +2792,24 @@ type ServiceTargetProtocol string
|
||||
// ServiceTargetTargetType Target type (e.g., "peer", "resource")
|
||||
type ServiceTargetTargetType string
|
||||
|
||||
// ServiceTargetOptions defines model for ServiceTargetOptions.
|
||||
type ServiceTargetOptions struct {
|
||||
// CustomHeaders Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
|
||||
CustomHeaders *map[string]string `json:"custom_headers,omitempty"`
|
||||
|
||||
// PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
|
||||
PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"`
|
||||
|
||||
// RequestTimeout Per-target response timeout as a Go duration string (e.g. "30s", "2m")
|
||||
RequestTimeout *string `json:"request_timeout,omitempty"`
|
||||
|
||||
// SkipTlsVerify Skip TLS certificate verification for this backend
|
||||
SkipTlsVerify *bool `json:"skip_tls_verify,omitempty"`
|
||||
}
|
||||
|
||||
// ServiceTargetOptionsPathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
|
||||
type ServiceTargetOptionsPathRewrite string
|
||||
|
||||
// SetupKey defines model for SetupKey.
|
||||
type SetupKey struct {
|
||||
// AllowExtraDnsLabels Allow extra DNS labels to be added to the peer
|
||||
@@ -3335,6 +3380,12 @@ type GetApiEventsProxyParams struct {
|
||||
// PageSize Number of items per page (max 100)
|
||||
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
|
||||
|
||||
// SortBy Field to sort by (url sorts by host then path)
|
||||
SortBy *GetApiEventsProxyParamsSortBy `form:"sort_by,omitempty" json:"sort_by,omitempty"`
|
||||
|
||||
// SortOrder Sort order (ascending or descending)
|
||||
SortOrder *GetApiEventsProxyParamsSortOrder `form:"sort_order,omitempty" json:"sort_order,omitempty"`
|
||||
|
||||
// Search General search across request ID, host, path, source IP, user email, and user name
|
||||
Search *string `form:"search,omitempty" json:"search,omitempty"`
|
||||
|
||||
@@ -3372,6 +3423,12 @@ type GetApiEventsProxyParams struct {
|
||||
EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"`
|
||||
}
|
||||
|
||||
// GetApiEventsProxyParamsSortBy defines parameters for GetApiEventsProxy.
|
||||
type GetApiEventsProxyParamsSortBy string
|
||||
|
||||
// GetApiEventsProxyParamsSortOrder defines parameters for GetApiEventsProxy.
|
||||
type GetApiEventsProxyParamsSortOrder string
|
||||
|
||||
// GetApiEventsProxyParamsMethod defines parameters for GetApiEventsProxy.
|
||||
type GetApiEventsProxyParamsMethod string
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ package management;
|
||||
|
||||
option go_package = "/proto";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
// ProxyService - Management is the SERVER, Proxy is the CLIENT
|
||||
@@ -50,9 +51,22 @@ enum ProxyMappingUpdateType {
|
||||
UPDATE_TYPE_REMOVED = 2;
|
||||
}
|
||||
|
||||
enum PathRewriteMode {
|
||||
PATH_REWRITE_DEFAULT = 0;
|
||||
PATH_REWRITE_PRESERVE = 1;
|
||||
}
|
||||
|
||||
message PathTargetOptions {
|
||||
bool skip_tls_verify = 1;
|
||||
google.protobuf.Duration request_timeout = 2;
|
||||
PathRewriteMode path_rewrite = 3;
|
||||
map<string, string> custom_headers = 4;
|
||||
}
|
||||
|
||||
message PathMapping {
|
||||
string path = 1;
|
||||
string target = 2;
|
||||
PathTargetOptions options = 3;
|
||||
}
|
||||
|
||||
message Authentication {
|
||||
|
||||
Reference in New Issue
Block a user