Compare commits

...

24 Commits

Author SHA1 Message Date
crn4
a7d244230f Merge remote-tracking branch 'origin/main' into feat/byod-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/proxy/manager.go
#	management/internals/modules/reverseproxy/proxy/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager_mock.go
#	management/internals/shared/grpc/proxy.go
#	management/server/store/sql_store.go
#	management/server/store/store.go
#	management/server/store/store_mock.go
#	proxy/management_integration_test.go
2026-05-06 17:33:13 +02:00
Zoltan Papp
f23aaa9ae7 [client] iOS: structured ResolvedIPs collection for domain routes (#6090)
* [client] iOS: structured ResolvedIPs collection for domain routes

Replace comma-joined ResolvedIPs string with a gomobile-friendly
ResolvedIPs collection (Add/Get/Size), mirroring the Android bridge
in client/android/network_domains.go.

This allows the iOS app to match domain-route resolved IPs against
connected peer routes without parsing CSV strings, fixing the route
status indicator for dynamic (DNS) routes.

* [client] iOS: align dynamic route exposure with Android bridge

For dynamic (DNS) routes the Swift side previously received
"invalid Prefix" as the Network value, forcing UI code to special-case
that sentinel. The Android bridge uses Domains.SafeString() instead so
peer.routes entries (which also derive from Domains.SafeString()) match
directly. Mirror that here.

Also fix the resolved IP lookup: resolvedDomains is keyed by the
resolved domain (e.g. api.ipify.org), not the configured pattern
(e.g. *.ipify.org). Group entries by ParentDomain like the daemon does
in client/server/network.go, so wildcard route patterns get their
resolved IPs populated.
2026-05-06 17:14:11 +02:00
Viktor Liu
f532976e05 [client] Add public key to debug bundle config.txt (#6092) 2026-05-06 13:42:47 +02:00
Viktor Liu
71a400f90f [client] Include MTU and SSH auth/JWT cache config in debug bundle (#6071) 2026-05-06 13:23:43 +02:00
Pascal Fischer
bfeb9b19ec [management] remove permissions from geolocations api (#6091) 2026-05-06 13:07:01 +02:00
Pascal Fischer
b19b7464ea [management] fix flaky invite token test (#6077) 2026-05-05 18:48:51 +02:00
Pascal Fischer
cfb1b3fe31 [proxy] consolidate mapping update (#6072) 2026-05-05 18:40:42 +02:00
Bethuel Mmbaga
3c28d29725 [management] Map Entra oid claim as Dex user ID (#6067) 2026-05-05 18:12:18 +03:00
crn4
49f170c21e fixed issue with byop shared accross accounts 2026-04-28 18:52:20 +02:00
crn4
f78bc6113e test fix 2026-04-28 17:59:43 +02:00
crn4
4f0d6ef8f9 delete cluster + fix front issues 2026-04-28 16:46:54 +02:00
crn4
e6a663ba20 merge main 2026-04-27 18:03:11 +02:00
crn4
7c4011d8e2 remove withContext from all sql calls 2026-04-27 17:28:56 +02:00
crn4
8fe2b5ec1e removed condition on 1 yop per account 2026-04-21 15:12:52 +02:00
crn4
e62521132c Merge remote-tracking branch 'origin/main' into feat/byod-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager.go
#	management/internals/modules/reverseproxy/proxy/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager_mock.go
#	management/internals/shared/grpc/proxy.go
#	management/server/store/sql_store.go
#	proxy/management_integration_test.go
2026-04-13 17:02:24 +03:00
crn4
de3cb06067 added proxy id to cluster api response 2026-03-31 00:28:19 +02:00
crn4
4fdc39c8f8 review comments 2026-03-24 15:37:31 +01:00
crn4
94149a9441 linter 2026-03-24 14:58:03 +01:00
crn4
38fd73fad6 merge main 2026-03-24 14:50:03 +01:00
crn4
9dd76b5a07 merge main 2026-03-24 14:20:03 +01:00
crn4
0b5380a7dc review comments 2026-03-24 13:32:38 +01:00
crn4
177171e437 change api 2026-03-19 21:49:04 +01:00
crn4
da57b0f276 rename byod to byop 2026-03-19 16:11:57 +01:00
crn4
26ba03f08e [proxy] feature: bring your own proxy 2026-03-19 01:02:46 +01:00
43 changed files with 3363 additions and 251 deletions

View File

@@ -21,6 +21,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
@@ -583,6 +584,9 @@ func isSensitiveEnvVar(key string) bool {
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
}
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
if g.internalConfig.NetworkMonitor != nil {
@@ -607,6 +611,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -633,6 +643,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}
func (g *BundleGenerator) addProf() (err error) {

View File

@@ -5,16 +5,21 @@ import (
"bytes"
"encoding/json"
"net"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -413,25 +413,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
// resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org),
// not the configured pattern (e.g. *.ipify.org). Group entries whose
// ParentDomain belongs to this route, mirroring the daemon logic in
// client/server/network.go.
domainList := make([]DomainInfo, 0, len(r.Domains))
domainIndex := make(map[domain.Domain]int, len(r.Domains))
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range info.Prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
domainIndex[d] = len(domainList)
domainList = append(domainList, DomainInfo{Domain: d.SafeString()})
}
for _, info := range resolvedDomains {
idx, ok := domainIndex[info.ParentDomain]
if !ok {
continue
}
for _, prefix := range info.Prefixes {
domainList[idx].AddResolvedIP(prefix.Addr().String())
}
}
domainDetails := DomainDetails{items: domainList}
// For dynamic (DNS) routes, expose the joined domain pattern as the
// Network value so it matches the peer.routes entries on the Swift
// side (mirroring the Android bridge in client/android/client.go).
netStr := r.Network.String()
if len(r.Domains) > 0 {
netStr = r.Domains.SafeString()
}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
})

View File

@@ -34,7 +34,34 @@ type DomainDetails struct {
type DomainInfo struct {
Domain string
ResolvedIPs string
resolvedIPs ResolvedIPs
}
func (d *DomainInfo) AddResolvedIP(ipAddress string) {
d.resolvedIPs.Add(ipAddress)
}
func (d *DomainInfo) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type ResolvedIPs struct {
items []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.items = append(r.items, ipAddress)
}
func (r *ResolvedIPs) Get(i int) string {
if i < 0 || i >= len(r.items) {
return ""
}
return r.items[i]
}
func (r *ResolvedIPs) Size() int {
return len(r.items)
}
// Add new PeerInfo to the collection

View File

@@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
}
// UpdateConnector updates an existing connector in Dex storage.
// It merges incoming updates with existing values to prevent data loss on partial updates.
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
// redirectURI) onto the stored connector config, and updates the connector name
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
// partial updates preserve create-time defaults such as scopes, claimMapping,
// and userIDKey.
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
oldCfg, err := p.parseStorageConnector(old)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
return storage.Connector{}, errors.New("connector type change not allowed")
}
mergeConnectorConfig(cfg, oldCfg)
storageConn, err := p.buildStorageConnector(cfg)
configData, err := overlayConnectorConfig(old.Config, cfg)
if err != nil {
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
}
return storageConn, nil
name := cfg.Name
if name == "" {
name = old.Name
}
return storage.Connector{
ID: cfg.ID,
Type: old.Type,
Name: name,
Config: configData,
}, nil
}); err != nil {
return fmt.Errorf("failed to update connector: %w", err)
}
@@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
return nil
}
// mergeConnectorConfig preserves existing values for empty fields in the update.
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
if cfg.ClientSecret == "" {
cfg.ClientSecret = oldCfg.ClientSecret
// overlayConnectorConfig writes only the user-mutable fields onto the existing
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
var m map[string]any
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
return nil, err
}
if cfg.RedirectURI == "" {
cfg.RedirectURI = oldCfg.RedirectURI
if cfg.Issuer != "" {
m["issuer"] = cfg.Issuer
}
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
cfg.Issuer = oldCfg.Issuer
if cfg.ClientID != "" {
m["clientID"] = cfg.ClientID
}
if cfg.ClientID == "" {
cfg.ClientID = oldCfg.ClientID
if cfg.ClientSecret != "" {
m["clientSecret"] = cfg.ClientSecret
}
if cfg.Name == "" {
cfg.Name = oldCfg.Name
if cfg.RedirectURI != "" {
m["redirectURI"] = cfg.RedirectURI
}
return encodeConnectorConfig(m)
}
// DeleteConnector removes a connector from Dex storage.
@@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
// Entra issues sub as a per-app pairwise identifier that does not match
// the stable Object ID.
oidcConfig["userIDKey"] = "oid"
case "okta":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":

205
idp/dex/connector_test.go Normal file
View File

@@ -0,0 +1,205 @@
package dex
import (
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestProvider(t *testing.T) (*Provider, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
require.NoError(t, err)
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
require.NoError(t, err)
return &Provider{storage: s, logger: logger}, func() {
_ = s.Close()
_ = os.RemoveAll(tmpDir)
}
}
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
cfg := &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "client-secret",
}
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
}
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
// ensures the Entra userIDKey override does not leak into other OIDC providers,
// which already use a stable sub claim.
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
t.Run(typ, func(t *testing.T) {
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(data, &m))
_, ok := m["userIDKey"]
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
})
}
}
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
created, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "old-secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
require.Equal(t, "entra-test", created.ID)
// Rotate only the client secret.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
}
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
// Seed a connector directly into storage without userIDKey
preFixConfig, err := json.Marshal(map[string]any{
"issuer": "https://login.microsoftonline.com/tid/v2.0",
"clientID": "client-id",
"clientSecret": "old-secret",
"redirectURI": "https://example.com/oauth2/callback",
"scopes": []string{"openid", "profile", "email"},
"claimMapping": map[string]string{"email": "preferred_username"},
})
require.NoError(t, err)
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
ID: "entra-prefix",
Type: "oidc",
Name: "Entra",
Config: preFixConfig,
}))
// Rotate client secret via UpdateConnector.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-prefix",
Type: "entra",
ClientSecret: "new-secret",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "new-secret", m["clientSecret"])
_, has := m["userIDKey"]
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
}
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/tid/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
// Attempt to switch the connector to okta.
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "okta",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "connector type change not allowed")
// stored connector type/config unchanged after the rejected update.
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
assert.Equal(t, "oidc", conn.Type)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "oid", m["userIDKey"])
}
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
ctx := context.Background()
p, cleanup := newTestProvider(t)
defer cleanup()
_, err := p.CreateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Name: "Entra",
Type: "entra",
Issuer: "https://login.microsoftonline.com/old/v2.0",
ClientID: "client-id",
ClientSecret: "secret",
RedirectURI: "https://example.com/oauth2/callback",
})
require.NoError(t, err)
err = p.UpdateConnector(ctx, &ConnectorConfig{
ID: "entra-test",
Type: "entra",
Issuer: "https://login.microsoftonline.com/new/v2.0",
})
require.NoError(t, err)
conn, err := p.storage.GetConnector(ctx, "entra-test")
require.NoError(t, err)
var m map[string]any
require.NoError(t, json.Unmarshal(conn.Config, &m))
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
}

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
if len(byopAddresses) > 0 {
return byopAddresses, nil
}
return m.proxyManager.GetActiveClusterAddresses(ctx)
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1

View File

@@ -0,0 +1,110 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}

View File

@@ -11,15 +11,19 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, p *Proxy) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -16,11 +16,16 @@ type store interface {
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// Manager handles all proxy operations
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
SessionID: sessionID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Status: proxy.StatusConnected,
Capabilities: caps,
}
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
}
// Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
return err
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,337 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, p)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if m.deleteAccountClusterFunc != nil {
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "session-1", savedProxy.SessionID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteAccountCluster(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedCluster, deletedAccount string
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
deletedCluster = clusterAddress
deletedAccount = accountID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
require.NoError(t, err)
assert.Equal(t, "cluster.example.com", deletedCluster)
assert.Equal(t, "acc-123", deletedAccount)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
}
// Disconnect mocks base method.
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
// Heartbeat mocks base method.
@@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
}
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller

View File

@@ -1,6 +1,13 @@
package proxy
import "time"
import (
"time"
)
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
@@ -21,6 +28,7 @@ type Proxy struct {
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
@@ -36,6 +44,8 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
ID string
Address string
ConnectedProxies int
SelfHosted bool
}

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -10,6 +10,7 @@ import (
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
@@ -28,4 +29,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusterAddress := mux.Vars(r)["clusterAddress"]
if clusterAddress == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
return
}
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -121,7 +121,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -985,6 +999,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil
}
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {

View File

@@ -433,7 +433,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -712,7 +712,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1135,7 +1135,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -9,8 +9,11 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -48,6 +51,11 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ProxyTokenChecker checks whether a proxy access token is still valid.
type ProxyTokenChecker interface {
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -76,22 +84,56 @@ type ProxyServiceServer struct {
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
// Checker for proxy access token validity
tokenChecker ProxyTokenChecker
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
sessionID string
address string
accountID *string
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
@@ -99,8 +141,19 @@ type proxyConnection struct {
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
token := GetProxyTokenFromContext(ctx)
if token == nil || token.AccountID == nil {
return nil
}
if requestAccountID == "" || *token.AccountID != requestAccountID {
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
}
return nil
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -110,6 +163,8 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
@@ -168,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
@@ -185,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
@@ -192,12 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
@@ -206,21 +276,43 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
@@ -241,14 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord)
go s.heartbeat(connCtx, conn, proxyRecord)
select {
case err := <-errChan:
@@ -260,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -271,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
}
if conn.tokenID != "" && s.tokenChecker != nil {
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
if err != nil {
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
continue
}
if !valid {
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
conn.cancel()
return
}
}
case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return
@@ -278,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
@@ -290,29 +387,40 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err
}
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
var services []*rpservice.Service
var err error
if conn.accountID != nil {
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
} else {
services, err = s.serviceManager.GetGlobalServices(ctx)
}
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
@@ -323,13 +431,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -341,8 +445,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
return mappings, nil
}
// isProxyAddressValid validates a proxy address
// isProxyAddressValid validates a proxy address (domain name or IP address)
func isProxyAddressValid(addr string) bool {
if addr == "" {
return false
}
if net.ParseIP(addr) != nil {
return true
}
_, err := domain.ValidateDomains([]string{addr})
return err == nil
}
@@ -366,6 +476,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
return nil, err
}
fields := log.Fields{
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
@@ -403,24 +517,68 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
connUpdate := update
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
if len(filtered) == 0 {
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true
}
select {
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
}
return true
})
}
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
conn.cancel()
s.connectedProxies.Delete(proxyID)
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
}
}
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
var filtered []*proto.ProxyMapping
for _, m := range mappings {
if m.AccountId == accountID {
filtered = append(filtered, m)
}
}
return filtered
}
// GetConnectedProxies returns a list of connected proxy IDs
func (s *ProxyServiceServer) GetConnectedProxies() []string {
var proxies []string
@@ -489,19 +647,25 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue
}
conn := connVal.(*proxyConnection)
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
continue
}
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
}
}
}
@@ -527,7 +691,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
// Returns nil if token generation fails; the caller must disconnect the
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
@@ -536,7 +701,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
@@ -572,6 +737,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
@@ -691,6 +860,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
@@ -761,6 +934,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
@@ -815,6 +992,10 @@ func strPtr(s string) *string {
}
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
@@ -943,21 +1124,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.serviceManager.GetGlobalServices(ctx)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
}
if service.SessionPrivateKey == "" {
@@ -1055,6 +1224,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
@@ -1138,18 +1311,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
return s.serviceManager.GetServiceByDomain(ctx, domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {

View File

@@ -0,0 +1,29 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsProxyAddressValid(t *testing.T) {
tests := []struct {
name string
addr string
valid bool
}{
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
{name: "valid IPv6", addr: "::1", valid: true},
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
{name: "empty string", addr: "", valid: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
})
}
}

View File

@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}

View File

@@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
@@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {
return svc, nil
}
}
}
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}

View File

@@ -0,0 +1,174 @@
package grpc
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// recordingStream captures all messages sent via Send so tests can inspect
// batching behaviour without a real gRPC transport.
type recordingStream struct {
grpc.ServerStream
messages []*proto.GetMappingUpdateResponse
}
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
s.messages = append(s.messages, m)
return nil
}
func (s *recordingStream) Context() context.Context { return context.Background() }
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
func (s *recordingStream) SetTrailer(metadata.MD) {}
func (s *recordingStream) SendMsg(any) error { return nil }
func (s *recordingStream) RecvMsg(any) error { return nil }
// makeServices creates n enabled services assigned to the given cluster.
func makeServices(n int, cluster string) []*rpservice.Service {
services := make([]*rpservice.Service, n)
for i := range n {
services[i] = &rpservice.Service{
ID: fmt.Sprintf("svc-%d", i),
AccountID: "acct-1",
Name: fmt.Sprintf("svc-%d", i),
Domain: fmt.Sprintf("svc-%d.example.com", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
},
}
}
return services
}
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
t.Helper()
s := &ProxyServiceServer{
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
snapshotBatchSize: batchSize,
}
s.SetProxyController(newTestProxyController())
return s
}
func TestSendSnapshot_BatchesMappings(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
stream: stream,
}
err := s.sendSnapshot(context.Background(), conn)
require.NoError(t, err)
// Expect ceil(7/3) = 3 messages
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
assert.Len(t, stream.messages[1].Mapping, 3)
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
assert.Len(t, stream.messages[2].Mapping, 1)
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
// Verify all service IDs are present exactly once
seen := make(map[string]bool)
for _, msg := range stream.messages {
for _, m := range msg.Mapping {
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
seen[m.Id] = true
}
}
assert.Len(t, seen, totalServices)
}
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 6 // exactly 2 batches
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 2)
assert.Len(t, stream.messages[0].Mapping, 3)
assert.False(t, stream.messages[0].InitialSyncComplete)
assert.Len(t, stream.messages[1].Mapping, 3)
assert.True(t, stream.messages[1].InitialSyncComplete)
}
func TestSendSnapshot_SingleBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
assert.Len(t, stream.messages[0].Mapping, totalServices)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &recordingStream{}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}

View File

@@ -12,9 +12,12 @@ import (
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -85,11 +88,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
ctx, cancel := context.WithCancel(context.Background())
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
capabilities: caps,
sendChan: ch,
ctx: ctx,
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
@@ -313,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format")
}
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
return nil
}
@@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil
}
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}

View File

@@ -3074,7 +3074,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -144,6 +145,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)

View File

@@ -7,11 +7,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
// getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
// getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,

View File

@@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
nil,
usersManager,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})
@@ -389,6 +390,10 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er
return nil
}
func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
@@ -435,6 +440,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}

View File

@@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -4495,6 +4495,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
return nil
}
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Where("account_id = ?", accountID).Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
}
return tokens, nil
}
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
if err != nil {
return false, err
}
return token.IsValid(), nil
}
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
}
return &token, nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.Model(&types.ProxyAccessToken{}).
@@ -5445,7 +5486,7 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", proxyID, sessionID).
Updates(map[string]any{
"status": "disconnected",
"status": proxy.StatusDisconnected,
"disconnected_at": now,
"last_seen": now,
})
@@ -5476,7 +5517,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
if result.RowsAffected == 0 {
p.LastSeen = now
p.ConnectedAt = &now
p.Status = "connected"
p.Status = proxy.StatusConnected
if err := s.db.Create(p).Error; err != nil {
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
}
@@ -5485,13 +5526,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
// GetActiveProxyClusterAddresses returns the unique cluster addresses of active
// shared proxies (those without an account scope). BYOP cluster addresses are
// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them.
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
@@ -5503,13 +5546,75 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
return addresses, nil
}
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
var addresses []string
result := s.db.
Model(&proxy.Proxy{}).
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
}
return addresses, nil
}
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
var p proxy.Proxy
result := s.db.Where("account_id = ?", accountID).Take(&p)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy not found for account")
}
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
}
return &p, nil
}
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
var count int64
result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
}
return count, nil
}
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
var count int64
result := s.db.
Model(&proxy.Proxy{}).
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
Count(&count)
if result.Error != nil {
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
}
return count > 0, nil
}
func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
result := s.db.
Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID).
Delete(&proxy.Proxy{})
if result.Error != nil {
return status.Errorf(status.Internal, "delete account cluster: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "cluster not found")
}
return nil
}
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
var clusters []proxy.Cluster
result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
Group("cluster_address").
Scan(&clusters)

View File

@@ -114,6 +114,9 @@ type Store interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
@@ -287,11 +290,16 @@ type Store interface {
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
@@ -495,6 +503,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
},
func(db *gorm.DB) error {
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
},
}
}

View File

@@ -165,20 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -237,6 +223,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
}
// CountProxiesByAccountID mocks base method.
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
}
// CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper()
@@ -575,6 +576,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
}
// DeleteAccountCluster mocks base method.
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// DeleteRoute mocks base method.
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
m.ctrl.T.Helper()
@@ -1301,19 +1316,34 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
// GetActiveProxyClusterAddressesForAccount mocks base method.
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
}
// GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
}
// GetAllAccounts mocks base method.
@@ -1389,6 +1419,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
@@ -1958,6 +2002,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
}
// GetProxyAccessTokenByID mocks base method.
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
ret0, _ := ret[0].(*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
}
// GetProxyAccessTokensByAccountID mocks base method.
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
}
// GetProxyByAccountID mocks base method.
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
ret0, _ := ret[0].(*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
}
// GetResourceGroups mocks base method.
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -2390,6 +2479,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
}
// IsClusterAddressConflicting mocks base method.
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
}
// IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper()
@@ -2406,6 +2510,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
}
// IsProxyAccessTokenValid mocks base method.
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
}
// ListCustomDomains mocks base method.
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
m.ctrl.T.Helper()

View File

@@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
replacement := "X"
if plainToken[5] == 'X' {
replacement = "Y"
}
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}

View File

@@ -0,0 +1,409 @@
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
type byopTestSetup struct {
store store.Store
proxyService *nbgrpc.ProxyServiceServer
grpcServer *grpc.Server
grpcAddr string
cleanup func()
accountA string
accountB string
accountAToken types.PlainProxyToken
accountBToken types.PlainProxyToken
accountACluster string
accountBCluster string
}
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
accountAID := "byop-account-a"
accountBID := "byop-account-b"
for _, acc := range []*types.Account{
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
} {
require.NoError(t, testStore.SaveAccount(ctx, acc))
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
clusterA := "byop-a.proxy.test"
clusterB := "byop-b.proxy.test"
services := []*service.Service{
{
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
},
}
for _, svc := range services {
require.NoError(t, testStore.CreateService(ctx, svc))
}
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
meter := noop.NewMeterProvider().Meter("test")
realProxyManager, err := proxymanager.NewManager(testStore, meter)
require.NoError(t, err)
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: "https://fake-issuer.example.com",
ClientID: "test-client",
HMACKey: []byte("test-hmac-key"),
}
usersManager := users.NewManager(testStore)
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,
realProxyManager,
nil,
)
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
proto.RegisterProxyServiceServer(grpcServer, proxyService)
go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("gRPC server error: %v", err)
}
}()
return &byopTestSetup{
store: testStore,
proxyService: proxyService,
grpcServer: grpcServer,
grpcAddr: lis.Addr().String(),
cleanup: func() {
grpcServer.GracefulStop()
authClose()
storeCleanup()
},
accountA: accountAID,
accountB: accountBID,
accountAToken: tokenA.PlainToken,
accountBToken: tokenB.PlainToken,
accountACluster: clusterA,
accountBCluster: clusterB,
}
}
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
md := metadata.Pairs("authorization", "Bearer "+string(token))
return metadata.NewOutgoingContext(ctx, md)
}
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
t.Helper()
var mappings []*proto.ProxyMapping
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
for _, m := range mappings {
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
}
ids := map[string]bool{}
for _, m := range mappings {
ids[m.GetId()] = true
}
assert.True(t, ids["svc-a1"], "should contain svc-a1")
assert.True(t, ids["svc-a2"], "should contain svc-a2")
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
}
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b",
Version: "test-v1",
Address: setup.accountBCluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
assert.Equal(t, "svc-b1", mappings[0].GetId())
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
}
func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-first",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings1 := receiveBYOPMappings(t, stream1)
assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services")
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-second",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings2 := receiveBYOPMappings(t, stream2)
assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services")
for _, m := range mappings2 {
assert.Equal(t, setup.accountA, m.GetAccountId())
}
}
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-cluster",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_ = receiveBYOPMappings(t, stream1)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b-conflict",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_, err = stream2.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
t.Logf("expected rejection: %s", st.Message())
}
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
proxyID := "byop-proxy-reconnect"
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
firstMappings := receiveBYOPMappings(t, stream1)
cancel1()
time.Sleep(200 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
secondMappings := receiveBYOPMappings(t, stream2)
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
firstIDs := map[string]bool{}
for _, m := range firstMappings {
firstIDs[m.GetId()] = true
}
for _, m := range secondMappings {
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
}
}
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "no-auth-proxy",
Version: "test-v1",
Address: "some.cluster.io",
})
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
nil,
usersManager,
proxyManager,
nil,
)
// Use store-backed service manager
@@ -201,8 +203,8 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
@@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
return nil, nil
}
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
return nil, nil
}
@@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
return nil
}
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
@@ -290,6 +312,10 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID
return nil
}
func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return nil
}
@@ -336,6 +362,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil
}
@@ -364,14 +394,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive 2 mappings total
@@ -411,12 +443,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive the 2 mappings matching the cluster
@@ -440,13 +474,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
@@ -460,7 +496,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2)
firstMappings := receiveMappings(stream1)
cancel1()
time.Sleep(100 * time.Millisecond)
@@ -476,7 +512,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2)
secondMappings := receiveMappings(stream2)
// Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -542,12 +578,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
}
}
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
applyMappings(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
}
@@ -636,12 +674,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
count += len(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
mu.Lock()
@@ -681,9 +721,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream1.Recv()
for {
msg, err := stream1.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
@@ -699,9 +742,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream2.Recv()
for {
msg, err := stream2.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
cancel1()

View File

@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
operation := func() error {
s.Logger.Debug("connecting to management mapping stream")
initialSyncDone = false
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return ctx.Err()
}
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
for {
// Check for context completion to gracefully shutdown.
select {
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone && msg.GetInitialSyncComplete() {
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
if !*initialSyncDone {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
s.portMu.RLock()
var stale []*proto.ProxyMapping
for svcID, mapping := range s.lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, mapping)
}
}
s.portMu.RUnlock()
for _, mapping := range stale {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Info("Removing stale mapping absent from snapshot")
s.removeMapping(ctx, mapping)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{

View File

@@ -0,0 +1,227 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
var stale []types.ServiceID
for svcID := range lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, svcID)
}
}
return stale
}
// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
"svc-stale-a": {Id: "svc-stale-a"},
"svc-stale-b": {Id: "svc-stale-b"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
"svc-3": {}, // new service, not in local
}
stale := collectStaleIDs(local, snapshot)
assert.Len(t, stale, 2)
staleSet := make(map[types.ServiceID]struct{})
for _, id := range stale {
staleSet[id] = struct{}{}
}
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
assert.Len(t, stale, 2)
}
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
stale := collectStaleIDs(local, snapshot)
assert.Empty(t, stale)
}
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
stale := collectStaleIDs(
map[types.ServiceID]*proto.ProxyMapping{},
map[types.ServiceID]struct{}{"svc-1": {}},
)
assert.Empty(t, stale)
}
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
snapshotIDs := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
// This should not panic — no stale entries means removeMapping is never called.
s.reconcileSnapshot(context.Background(), snapshotIDs)
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
assert.Empty(t, s.lastMappings)
}
// --- handleMappingStream tests for batched snapshot ID accumulation ---
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // batch 1: no sync-complete
{}, // batch 2: no sync-complete
{InitialSyncComplete: true}, // batch 3: sync done
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync should be marked done after final batch")
}
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// Simulate state left over from a previous sync.
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // post-sync empty message — must not reconcile
},
}
syncDone := true // sync already completed in a previous stream
err := s.handleMappingStream(context.Background(), stream, &syncDone)
require.NoError(t, err)
assert.Len(t, s.lastMappings, 2,
"post-sync messages must not trigger reconciliation — all entries should survive")
}
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
stream := &mockMappingStream{} // no messages → immediate EOF
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}
// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
mockMappingStream
calls int
}
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
m.calls++
if m.calls == 1 {
return &proto.GetMappingUpdateResponse{}, nil
}
return nil, io.ErrUnexpectedEOF
}
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
syncDone := false
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
assert.Error(t, err)
assert.False(t, syncDone)
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}

View File

@@ -3327,10 +3327,64 @@ components:
example: false
required:
- enabled
ProxyTokenRequest:
type: object
properties:
name:
type: string
description: Human-readable token name
example: "my-proxy-token"
expires_in:
type: integer
minimum: 0
description: Token expiration in seconds (0 = never expires)
example: 0
required:
- name
ProxyToken:
type: object
properties:
id:
type: string
name:
type: string
expires_at:
type: string
format: date-time
created_at:
type: string
format: date-time
last_used:
type: string
format: date-time
revoked:
type: boolean
required:
- id
- name
- created_at
- revoked
ProxyTokenCreated:
type: object
description: Returned on creation — plain_token is shown only once
allOf:
- $ref: '#/components/schemas/ProxyToken'
- type: object
properties:
plain_token:
type: string
description: The plain text token (shown only once)
example: "nbx_abc123..."
required:
- plain_token
ProxyCluster:
type: object
description: A proxy cluster represents a group of proxy nodes serving the same address
properties:
id:
type: string
description: Unique identifier of a proxy in this cluster
example: "chlfq4q5r8kc73b0qjpg"
address:
type: string
description: Cluster address used for CNAME targets
@@ -3339,9 +3393,15 @@ components:
type: integer
description: Number of proxy nodes connected in this cluster
example: 3
self_hosted:
type: boolean
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
example: false
required:
- id
- address
- connected_proxies
- self_hosted
ReverseProxyDomainType:
type: string
description: Type of Reverse Proxy Domain
@@ -11347,6 +11407,111 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/clusters/{clusterAddress}:
delete:
summary: Delete a self-hosted proxy cluster
description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account.
tags: [ Services ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: clusterAddress
required: true
schema:
type: string
description: The address of the proxy cluster
responses:
'200':
description: Proxy cluster deleted successfully
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens:
get:
summary: List Proxy Tokens
description: Returns all proxy access tokens for the account
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of proxy tokens
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/ProxyToken'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Proxy Token
description: Generate an account-scoped proxy access token for self-hosted proxy registration
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenRequest'
responses:
'200':
description: Proxy token created (plain token shown once)
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenCreated'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens/{tokenId}:
delete:
summary: Revoke a Proxy Token
description: Revoke an account-scoped proxy access token
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: tokenId
required: true
schema:
type: string
description: The unique identifier of the proxy token
responses:
'200':
description: Token revoked
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/services:
get:
summary: List all Services

View File

@@ -3764,11 +3764,49 @@ type ProxyAccessLogsResponse struct {
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
type ProxyCluster struct {
// Id Unique identifier of a proxy in this cluster
Id string `json:"id"`
// Address Cluster address used for CNAME targets
Address string `json:"address"`
// ConnectedProxies Number of proxy nodes connected in this cluster
ConnectedProxies int `json:"connected_proxies"`
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
SelfHosted bool `json:"self_hosted"`
}
// ProxyToken defines model for ProxyToken.
type ProxyToken struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
Revoked bool `json:"revoked"`
}
// ProxyTokenCreated defines model for ProxyTokenCreated.
type ProxyTokenCreated struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
// PlainToken The plain text token (shown only once)
PlainToken string `json:"plain_token"`
Revoked bool `json:"revoked"`
}
// ProxyTokenRequest defines model for ProxyTokenRequest.
type ProxyTokenRequest struct {
// ExpiresIn Token expiration in seconds (0 = never expires)
ExpiresIn *int `json:"expires_in,omitempty"`
// Name Human-readable token name
Name string `json:"name"`
}
// Resource defines model for Resource.
@@ -5139,6 +5177,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest