Compare commits

..

3 Commits

Author SHA1 Message Date
pascal
e8156ecbb6 add cert 2026-03-04 17:30:21 +01:00
pascal
e14ddaad57 expose fileserver 2026-03-04 15:37:23 +01:00
pascal
65e627febc add static file download 2026-03-04 15:26:19 +01:00
60 changed files with 1084 additions and 4411 deletions

View File

@@ -1,14 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -1,51 +0,0 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -849,26 +849,14 @@ func (s *Server) cleanupConnection() error {
if s.actCancel == nil {
return ErrServiceNotUp
}
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
s.actCancel()
if s.connectClient == nil {
return nil
}
if engine != nil {
if err := engine.Stop(); err != nil {
return err
}
if err := s.connectClient.Stop(); err != nil {
return err
}
s.connectClient = nil

View File

@@ -46,10 +46,8 @@ const (
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
// Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (

View File

@@ -493,6 +493,9 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
@@ -504,9 +507,8 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{
NbConfig: mgmtConfig,
DNSDomain: "",
MgmtSingleAccModeDomain: "",
AutoResolveDomains: true,
DNSDomain: dnsDomain,
MgmtSingleAccModeDomain: singleAccModeDomain,
MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics,

View File

@@ -24,8 +24,6 @@ type AccessLogEntry struct {
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -41,8 +39,6 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
a.BytesUpload = serviceLog.GetBytesUpload()
a.BytesDownload = serviceLog.GetBytesDownload()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
@@ -105,7 +101,5 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
}
}

View File

@@ -15,12 +15,3 @@ type Domain struct {
Type Type `gorm:"-"`
Validated bool
}
// EventMeta builds the metadata map used for activity events about this domain:
// the domain name, its target cluster, and whether it has been validated.
func (d *Domain) EventMeta() map[string]any {
	meta := make(map[string]any, 3)
	meta["domain"] = d.Domain
	meta["target_cluster"] = d.TargetCluster
	meta["validated"] = d.Validated
	return meta
}

View File

@@ -9,8 +9,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -38,16 +36,16 @@ type Manager struct {
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
}
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager,
accountManager: accountManager,
}
}
@@ -138,9 +136,6 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
return d, nil
}
@@ -153,18 +148,10 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
return nil
}
@@ -231,8 +218,6 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}).WithError(err).Error("update custom domain in store")
return
}
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
} else {
log.WithFields(log.Fields{
"accountID": accountID,

View File

@@ -73,10 +73,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
}
service := new(rpservice.Service)
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
@@ -135,10 +132,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service := new(rpservice.Service)
service.ID = serviceID
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)

View File

@@ -2,7 +2,7 @@ package manager
import (
"context"
"math/rand/v2"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/status"
@@ -13,20 +13,108 @@ const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
exposeReapBatch = 100
)
type exposeReaper struct {
manager *Manager
type trackedExpose struct {
mu sync.Mutex
domain string
accountID string
peerID string
lastRenewed time.Time
expiring bool
}
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
go func() {
// start with a random delay
rn := rand.IntN(10)
time.Sleep(time.Duration(rn) * time.Second)
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *Manager
}
// exposeKey builds the tracking-map key for an expose session by joining the
// peer ID and the domain with a colon.
func exposeKey(peerID, domain string) string {
	const separator = ":"
	return peerID + separator + domain
}
// TrackExposeIfAllowed registers a new active expose session for the peer while
// holding exposeCreateMu, so the duplicate check and the per-peer limit check
// run under the same lock.
//
// Results:
//   - (true, false): an entry for this peer/domain pair was already tracked.
//   - (false, true): the session was registered.
//   - (false, false): registering would exceed maxExposesPerPeer; the entry
//     that was just stored is removed again.
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
	t.exposeCreateMu.Lock()
	defer t.exposeCreateMu.Unlock()

	key := exposeKey(peerID, domain)
	entry := &trackedExpose{
		domain:      domain,
		accountID:   accountID,
		peerID:      peerID,
		lastRenewed: time.Now(),
	}

	if _, duplicate := t.activeExposes.LoadOrStore(key, entry); duplicate {
		return true, false
	}

	// The new entry is already in the map here, so the count includes it.
	if t.CountPeerExposes(peerID) <= maxExposesPerPeer {
		return false, true
	}

	t.activeExposes.Delete(key)
	return false, false
}
// UntrackExpose drops the peer/domain entry from the active expose map.
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
	key := exposeKey(peerID, domain)
	t.activeExposes.Delete(key)
}
// CountPeerExposes walks every tracked expose and reports how many of them
// belong to the given peer.
func (t *exposeTracker) CountPeerExposes(peerID string) int {
	var total int
	t.activeExposes.Range(func(_, value any) bool {
		tracked := value.(*trackedExpose)
		if tracked.peerID == peerID {
			total++
		}
		// Always continue: every entry has to be inspected.
		return true
	})
	return total
}
// MaxExposesPerPeer reports the per-peer cap on concurrent expose sessions,
// i.e. the maxExposesPerPeer constant.
func (t *exposeTracker) MaxExposesPerPeer() int {
	limit := maxExposesPerPeer
	return limit
}
// RenewTrackedExpose refreshes the in-memory lastRenewed timestamp of a tracked
// expose session. It reports false when no entry exists for the peer/domain
// pair or when the entry has been flagged as expiring.
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
	entry, found := t.activeExposes.Load(exposeKey(peerID, domain))
	if !found {
		return false
	}

	tracked := entry.(*trackedExpose)
	tracked.mu.Lock()
	defer tracked.mu.Unlock()

	// An entry flagged as expiring must not be revived.
	if tracked.expiring {
		return false
	}

	tracked.lastRenewed = time.Now()
	return true
}
// StopTrackedExpose deletes the peer/domain entry from the active expose map
// and reports whether such an entry was present.
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
	_, wasTracked := t.activeExposes.LoadAndDelete(exposeKey(peerID, domain))
	return wasTracked
}
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
go func() {
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
@@ -35,31 +123,41 @@ func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
r.reapExpiredExposes(ctx)
t.reapExpiredExposes()
}
}
}()
}
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
if err != nil {
log.Errorf("failed to get expired ephemeral services: %v", err)
return
}
func (t *exposeTracker) reapExpiredExposes() {
t.activeExposes.Range(func(key, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expired := time.Since(expose.lastRenewed) > exposeTTL
if expired {
expose.expiring = true
}
expose.mu.Unlock()
for _, svc := range expired {
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
if err == nil {
continue
if !expired {
return true
}
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
log.Debugf("service %s was already deleted by another instance", svc.Domain)
} else {
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
s, _ := status.FromError(err)
switch {
case err == nil:
t.activeExposes.Delete(key)
case s.ErrorType == status.NotFound:
log.Debugf("service %s was already deleted", expose.domain)
default:
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
}
}
return true
})
}

View File

@@ -10,62 +10,184 @@ import (
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
)
// TestExposeKey checks that exposeKey joins the peer ID and domain with a colon
// and that different domains for the same peer produce different keys.
func TestExposeKey(t *testing.T) {
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
}
// TestTrackExposeIfAllowed covers the three result combinations of
// TrackExposeIfAllowed (tracked, duplicate, over the limit) and checks that the
// limit is applied per peer. Each subtest uses a fresh zero-value tracker.
func TestTrackExposeIfAllowed(t *testing.T) {
// A new peer/domain pair is registered: (false, true).
t.Run("first track succeeds", func(t *testing.T) {
tracker := &exposeTracker{}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.False(t, alreadyTracked, "first track should not be duplicate")
assert.True(t, ok, "first track should be allowed")
})
// Registering the same peer/domain pair twice reports a duplicate: (true, false).
t.Run("duplicate track detected", func(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.True(t, alreadyTracked, "second track should be duplicate")
assert.False(t, ok)
})
// After maxExposesPerPeer distinct domains, one more is rejected: (false, false).
t.Run("rejects when at limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
assert.True(t, ok, "track %d should be allowed", i)
}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
assert.False(t, alreadyTracked)
assert.False(t, ok, "should reject when at limit")
})
// Filling peer1 up to the limit must not prevent peer2 from registering.
t.Run("other peer unaffected by limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
}
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.True(t, ok, "other peer should still be within limit")
})
}
// TestUntrackExpose checks that UntrackExpose removes a previously tracked
// entry, observed through the peer's count dropping from 1 to 0.
func TestUntrackExpose(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
tracker.UntrackExpose("peer1", "a.com")
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
}
// TestCountPeerExposes checks that CountPeerExposes counts entries per peer:
// zero on an empty tracker, and only the matching peer's entries afterwards.
func TestCountPeerExposes(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
// Same domain as peer1's first entry, but a different peer.
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
}
// TestMaxExposesPerPeer checks that the accessor returns the maxExposesPerPeer constant.
func TestMaxExposesPerPeer(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
}
// TestRenewTrackedExpose checks that renewing fails for an untracked
// peer/domain pair and succeeds once the pair has been tracked.
func TestRenewTrackedExpose(t *testing.T) {
tracker := &exposeTracker{}
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should not find untracked expose")
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
found = tracker.RenewTrackedExpose("peer1", "a.com")
assert.True(t, found, "should find tracked expose")
}
// TestRenewTrackedExpose_RejectsExpiring checks that RenewTrackedExpose returns
// false for an entry whose expiring flag is set, even though it is still in the map.
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
// Simulate reaper marking the expose as expiring
key := exposeKey("peer1", "a.com")
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
// Set the flag under the entry's own mutex, the same lock RenewTrackedExpose takes.
expose.mu.Lock()
expose.expiring = true
expose.mu.Unlock()
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should reject renewal when expiring")
}
func TestReapExpiredExposes(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the service by backdating meta_last_renewed_at
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Manually expire the tracked entry
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
// Add an active (non-expired) tracking entry
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
domain: "active.com",
accountID: testAccountID,
peerID: "peer1",
lastRenewed: time.Now(),
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
tracker.reapExpiredExposes()
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired service should be deleted")
_, exists := tracker.activeExposes.Load(key)
assert.False(t, exists, "expired expose should be removed")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
require.NoError(t, err, "active service should remain")
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
assert.True(t, exists, "active expose should remain")
}
func TestReapAlreadyDeletedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Expire it
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Reaping should handle the already-deleted service gracefully
mgr.exposeReaper.reapExpiredExposes(ctx)
// Renew should succeed before reaping
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
// Re-expire and reap
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
tracker.reapExpiredExposes()
// Entry is deleted, renew returns false
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
}
func TestConcurrentReapAndRenew(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
func TestConcurrentTrackAndCount(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
for i := range 5 {
@@ -76,133 +198,59 @@ func TestConcurrentReapAndRenew(t *testing.T) {
require.NoError(t, err)
}
// Expire all services
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
for _, svc := range services {
if svc.Source == rpservice.SourceEphemeral {
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
}
}
// Manually expire all tracked entries
tracker.activeExposes.Range(func(_, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
return true
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
mgr.exposeReaper.reapExpiredExposes(ctx)
tracker.reapExpiredExposes()
}()
go func() {
defer wg.Done()
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
tracker.CountPeerExposes(testPeerID)
}()
wg.Wait()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all expired services should be reaped")
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
}
func TestRenewEphemeralService(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
}
func TestCountAndExistsEphemeralServices(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
})
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
assert.True(t, exists, "service should exist")
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
require.NoError(t, err)
assert.False(t, exists, "non-existent service should not exist")
}
func TestMaxExposesPerPeerEnforced(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
ctx := context.Background()
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
expose := &trackedExpose{
lastRenewed: time.Now().Add(-1 * time.Hour),
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
}
func TestReapSkipsRenewedService(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
})
require.NoError(t, err)
// Expire the service
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)
svc.Meta.LastRenewedAt = &expired
err = s.UpdateService(context.Background(), svc)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
}
}()
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
_ = time.Since(expose.lastRenewed)
expose.mu.Unlock()
}
}()
wg.Wait()
expose.mu.Lock()
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
expose.mu.Unlock()
}

View File

@@ -37,7 +37,7 @@ type Manager struct {
permissionsManager permissions.Manager
proxyController proxy.Controller
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
exposeTracker *exposeTracker
}
// NewManager creates a new service manager.
@@ -49,13 +49,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr
}
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
// StartExposeReaper delegates to the expose tracker.
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
m.exposeTracker.StartExposeReaper(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
}
@@ -215,54 +215,8 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, servi
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
@@ -271,7 +225,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "domain already taken")
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
}
return nil
@@ -352,7 +306,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
@@ -458,6 +412,10 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
return err
}
if s.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
@@ -499,6 +457,9 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, svc := range services {
if svc.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
}
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
@@ -720,15 +681,28 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
return nil, err
}
svc.SourcePeer = peerID
now := time.Now()
svc.Meta.LastRenewedAt = &now
svc.SourcePeer = peerID
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
if err := m.persistNewService(ctx, accountID, svc); err != nil {
return nil, err
}
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
if alreadyTracked {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
@@ -774,17 +748,26 @@ func (m *Manager) buildRandomDomain(name string) (string, error) {
return domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
// Returns an error if the expose is not actively tracked.
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
	// Remove the persisted service first; a failure here aborts before the
	// tracker is touched.
	err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false)
	if err != nil {
		log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
		return err
	}

	// The service is gone at this point, so a missing tracker entry is only
	// worth a warning, not an error.
	if stopped := m.exposeTracker.StopTrackedExpose(peerID, domain); stopped {
		return nil
	}
	log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
	return nil
}
@@ -805,7 +788,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID,
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
@@ -865,57 +848,6 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
return nil
}
// deleteExpiredPeerService removes the ephemeral service serviceID owned by
// peerID, but only if it is still expired when re-read with an update lock
// inside a transaction. A service that was renewed between the batch query and
// this call is left untouched; the locked re-check is intended to ensure that
// only one management instance processes the deletion.
//
// When the service was actually deleted, an expose-expired event is stored,
// a delete mapping is sent to the proxy cluster and the account peers are
// updated. If nothing was deleted the function returns nil without side effects.
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
	var (
		svc     *service.Service
		removed bool
	)

	deleteIfExpired := func(tx store.Store) error {
		var txErr error
		svc, txErr = tx.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
		if txErr != nil {
			return txErr
		}

		ownedByPeer := svc.Source == service.SourceEphemeral && svc.SourcePeer == peerID
		if !ownedByPeer {
			return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
		}

		// A missing renewal timestamp is treated as expired.
		lastRenewed := svc.Meta.LastRenewedAt
		stillValid := lastRenewed != nil && time.Since(*lastRenewed) <= exposeTTL
		if stillValid {
			return nil
		}

		if txErr = tx.DeleteService(ctx, accountID, serviceID); txErr != nil {
			return fmt.Errorf("delete service: %w", txErr)
		}
		removed = true
		return nil
	}

	if err := m.store.ExecuteInTransaction(ctx, deleteIfExpired); err != nil {
		return err
	}
	if !removed {
		return nil
	}

	// Peer details only enrich the event metadata; a failed lookup is not fatal.
	peer, lookupErr := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
	if lookupErr != nil {
		log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, lookupErr)
		peer = nil
	}

	eventMeta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
	m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, eventMeta)

	deleteMapping := svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig())
	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, deleteMapping, svc.ProxyCluster)
	m.accountManager.UpdateAccountPeers(ctx, accountID)
	return nil
}
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
if peer == nil {
return meta

View File

@@ -72,6 +72,7 @@ func TestInitializeServiceForCreate(t *testing.T) {
func TestCheckDomainAvailable(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
tests := []struct {
name string
@@ -87,7 +88,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, "available.com").
GetServiceByDomain(ctx, accountID, "available.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
},
expectedError: false,
@@ -98,7 +99,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, "exists.com").
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
@@ -110,7 +111,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-123",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, "exists.com").
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
@@ -121,7 +122,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-456",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, "exists.com").
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
@@ -133,7 +134,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, "error.com").
GetServiceByDomain(ctx, accountID, "error.com").
Return(nil, errors.New("database error"))
},
expectedError: true,
@@ -149,7 +150,7 @@ func TestCheckDomainAvailable(t *testing.T) {
tt.setupMock(mockStore)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
require.Error(t, err)
@@ -167,6 +168,7 @@ func TestCheckDomainAvailable(t *testing.T) {
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("empty domain", func(t *testing.T) {
ctrl := gomock.NewController(t)
@@ -174,11 +176,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, "").
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
})
@@ -189,11 +191,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, "test.com").
GetServiceByDomain(ctx, accountID, "test.com").
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
sErr, ok := status.FromError(err)
@@ -207,11 +209,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, "nil.com").
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
})
@@ -239,7 +241,7 @@ func TestPersistNewService(t *testing.T) {
// Create another mock for the transaction
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, "new.com").
GetServiceByDomain(ctx, accountID, "new.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
txMock.EXPECT().
CreateService(ctx, service).
@@ -270,7 +272,7 @@ func TestPersistNewService(t *testing.T) {
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, "existing.com").
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
@@ -423,9 +425,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
return srv
}
@@ -704,9 +705,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -720,7 +720,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
domains: []string{"test.netbird.io"},
},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr, testStore
}
@@ -814,7 +814,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store
persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
@@ -977,7 +977,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
require.NoError(t, err)
// Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
@@ -1012,43 +1012,41 @@ func TestStopServiceFromPeer(t *testing.T) {
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
}
func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create")
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
// Look up the service by domain to get its store ID
svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
// Delete via the API path (user-initiated)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
// A new expose should succeed (not blocked by stale tracking)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
assert.NoError(t, err, "new expose should succeed after API delete cleared tracking")
}
func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
@@ -1060,16 +1058,12 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
require.NoError(t, err)
}
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(3), count, "all ephemeral services should exist")
assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked")
err = mgr.DeleteAllServices(ctx, testAccountID, testUserID)
err := mgr.DeleteAllServices(ctx, testAccountID, testUserID)
require.NoError(t, err)
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
require.NoError(t, err)
assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices")
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices")
}
func TestRenewServiceFromPeer(t *testing.T) {
@@ -1136,9 +1130,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -6,16 +6,13 @@ import (
"fmt"
"math/big"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
@@ -52,25 +49,17 @@ const (
SourceEphemeral = "ephemeral"
)
// TargetOptions holds optional per-target settings. A value with every field
// at its zero value is treated as "no options" by targetOptionsToAPI and
// targetOptionsToProto, which both return nil in that case.
type TargetOptions struct {
// SkipTLSVerify is forwarded to the proxy as PathTargetOptions.SkipTlsVerify.
// NOTE(review): presumably disables upstream certificate verification — confirm on the proxy side.
SkipTLSVerify bool `json:"skip_tls_verify"`
// RequestTimeout is optional; zero means unset. validateTargetOptions rejects
// negative values and values above maxRequestTimeout.
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
// PathRewrite must be empty or PathRewritePreserve (see validateTargetOptions).
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
// CustomHeaders is stored through gorm's JSON serializer and is checked by
// validateCustomHeaders.
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
@@ -134,7 +123,7 @@ type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"type:varchar(255);uniqueIndex"`
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
@@ -144,8 +133,8 @@ type Service struct {
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
Source string `gorm:"default:'permanent'"`
SourcePeer string
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
@@ -205,7 +194,7 @@ func (s *Service) ToAPIResponse() *api.Service {
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
st := api.ServiceTarget{
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
@@ -213,9 +202,7 @@ func (s *Service) ToAPIResponse() *api.Service {
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
}
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
})
}
meta := api.ServiceMeta{
@@ -269,14 +256,10 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
if target.Path != nil {
path = *target.Path
}
pm := &proto.PathMapping{
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
})
}
auth := &proto.Authentication{
@@ -329,87 +312,13 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
// PathRewriteMode controls how the request path is rewritten before forwarding.
// Any value other than PathRewritePreserve (including the empty string) is
// mapped to the proto default mode by pathRewriteToProto.
type PathRewriteMode string
const (
// PathRewritePreserve is converted to proto.PathRewriteMode_PATH_REWRITE_PRESERVE.
PathRewritePreserve PathRewriteMode = "preserve"
)
// pathRewriteToProto converts the internal rewrite mode to its proto enum.
// Every mode other than PathRewritePreserve maps to the proto default.
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
	if mode == PathRewritePreserve {
		return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
	}
	return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
// targetOptionsToAPI converts target options into their API representation.
// It returns nil when no option is set; otherwise only the fields that carry
// a non-zero value are populated in the result.
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
	hasTimeout := opts.RequestTimeout != 0
	hasRewrite := opts.PathRewrite != ""
	hasHeaders := len(opts.CustomHeaders) > 0

	if !(opts.SkipTLSVerify || hasTimeout || hasRewrite || hasHeaders) {
		return nil
	}

	var out api.ServiceTargetOptions
	if opts.SkipTLSVerify {
		out.SkipTlsVerify = &opts.SkipTLSVerify
	}
	if hasTimeout {
		// The API carries the timeout as a duration string such as "30s".
		timeout := opts.RequestTimeout.String()
		out.RequestTimeout = &timeout
	}
	if hasRewrite {
		mode := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
		out.PathRewrite = &mode
	}
	if hasHeaders {
		out.CustomHeaders = &opts.CustomHeaders
	}
	return &out
}
// targetOptionsToProto converts target options into the proto message sent to
// the proxy. It returns nil when every option is at its zero value.
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
	anySet := opts.SkipTLSVerify ||
		opts.PathRewrite != "" ||
		opts.RequestTimeout != 0 ||
		len(opts.CustomHeaders) != 0
	if !anySet {
		return nil
	}

	result := &proto.PathTargetOptions{
		SkipTlsVerify: opts.SkipTLSVerify,
		PathRewrite:   pathRewriteToProto(opts.PathRewrite),
		CustomHeaders: opts.CustomHeaders,
	}
	// The timeout stays unset in the message unless one was configured.
	if timeout := opts.RequestTimeout; timeout != 0 {
		result.RequestTimeout = durationpb.New(timeout)
	}
	return result
}
// targetOptionsFromAPI converts API target options into the internal form.
// idx is the target's position in the request and is used only in the error
// message. o must be non-nil. The only failure is an unparsable
// request_timeout; in that case the options populated so far are returned
// together with the error.
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
	var result TargetOptions

	if skip := o.SkipTlsVerify; skip != nil {
		result.SkipTLSVerify = *skip
	}

	if raw := o.RequestTimeout; raw != nil {
		timeout, err := time.ParseDuration(*raw)
		if err != nil {
			return result, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *raw, err)
		}
		result.RequestTimeout = timeout
	}

	if mode := o.PathRewrite; mode != nil {
		result.PathRewrite = PathRewriteMode(*mode)
	}

	if headers := o.CustomHeaders; headers != nil {
		result.CustomHeaders = *headers
	}

	return result, nil
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for i, apiTarget := range req.Targets {
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
@@ -422,13 +331,6 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target)
}
s.Targets = targets
@@ -466,8 +368,6 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
}
s.Auth.BearerAuth = bearerAuth
}
return nil
}
func (s *Service) Validate() error {
@@ -500,113 +400,11 @@ func (s *Service) Validate() error {
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
}
return nil
}
// Limits enforced on per-target options by validateTargetOptions and
// validateCustomHeaders.
const (
// maxRequestTimeout is the largest request_timeout accepted for a target.
maxRequestTimeout = 5 * time.Minute
// maxCustomHeaders is the maximum number of custom headers per target.
maxCustomHeaders = 16
// maxHeaderKeyLen is the maximum length, in bytes, of a custom header name.
maxHeaderKeyLen = 128
// maxHeaderValueLen is the maximum length, in bytes, of a custom header value.
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
var hopByHopHeaders = map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Proxy-Connection": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
}
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
var reservedHeaders = map[string]struct{}{
"Content-Length": {},
"Content-Type": {},
"Cookie": {},
"Forwarded": {},
"X-Forwarded-For": {},
"X-Forwarded-Host": {},
"X-Forwarded-Port": {},
"X-Forwarded-Proto": {},
"X-Real-Ip": {},
}
// validateTargetOptions checks the options of the target at position idx.
// It rejects unknown path rewrite modes, negative request timeouts, timeouts
// above maxRequestTimeout, and any custom header that validateCustomHeaders
// refuses. A zero timeout means "not set" and is accepted.
func validateTargetOptions(idx int, opts *TargetOptions) error {
	switch opts.PathRewrite {
	case "", PathRewritePreserve:
		// Supported modes.
	default:
		return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
	}

	switch timeout := opts.RequestTimeout; {
	case timeout < 0:
		return fmt.Errorf("target %d: request_timeout must be positive", idx)
	case timeout > maxRequestTimeout:
		return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
	}

	return validateCustomHeaders(idx, opts.CustomHeaders)
}
// validateCustomHeaders checks the custom headers of the target at position
// idx. It enforces the header count limit and then, per header: a valid
// header name, the key and value length limits, no CR/LF characters, no two
// keys that canonicalize to the same name, and no hop-by-hop, proxy-managed
// or Host header. The first violation found is returned; because map
// iteration order is unspecified, which one is reported may vary when several
// headers are invalid.
func validateCustomHeaders(idx int, headers map[string]string) error {
	if count := len(headers); count > maxCustomHeaders {
		return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, count, maxCustomHeaders)
	}

	// Maps each canonical header name to the original key that produced it.
	canonicalToKey := make(map[string]string, len(headers))
	for name, val := range headers {
		switch {
		case !httpHeaderNameRe.MatchString(name):
			return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, name)
		case len(name) > maxHeaderKeyLen:
			return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, name, maxHeaderKeyLen)
		case len(val) > maxHeaderValueLen:
			return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, name, maxHeaderValueLen)
		case containsCRLF(name) || containsCRLF(val):
			return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, name)
		}

		canon := http.CanonicalHeaderKey(name)
		if earlier, dup := canonicalToKey[canon]; dup {
			return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, earlier, name, canon)
		}
		canonicalToKey[canon] = name

		_, isHopByHop := hopByHopHeaders[canon]
		_, isReserved := reservedHeaders[canon]
		switch {
		case isHopByHop:
			return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, name)
		case isReserved:
			return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, name)
		case canon == "Host":
			return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
		}
	}
	return nil
}
// containsCRLF reports whether s contains a carriage return or a line feed.
func containsCRLF(s string) bool {
	return strings.IndexAny(s, "\r\n") >= 0
}
// EventMeta returns the service attributes attached to activity events:
// name, domain, proxy cluster, source and whether authentication is enabled.
func (s *Service) EventMeta() map[string]any {
	meta := make(map[string]any, 5)
	meta["name"] = s.Name
	meta["domain"] = s.Domain
	meta["proxy_cluster"] = s.ProxyCluster
	meta["source"] = s.Source
	meta["auth"] = s.isAuthEnabled()
	return meta
}
@@ -619,12 +417,6 @@ func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -88,188 +87,6 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
assert.Contains(t, err.Error(), "empty target_id")
}
// TestValidateTargetOptions_PathRewrite checks that Validate accepts the empty
// and "preserve" rewrite modes and rejects any other value.
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
	cases := []struct {
		name    string
		mode    PathRewriteMode
		wantErr string
	}{
		{name: "empty is default"},
		{name: "preserve is valid", mode: PathRewritePreserve},
		{name: "unknown rejected", mode: "regex", wantErr: "unknown path_rewrite mode"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := validProxy()
			svc.Targets[0].Options.PathRewrite = tc.mode

			err := svc.Validate()
			if tc.wantErr != "" {
				assert.ErrorContains(t, err, tc.wantErr)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestValidateTargetOptions_RequestTimeout checks the accepted range of the
// request timeout: zero and values up to the maximum pass, negative values
// and values above the maximum are rejected.
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
	cases := []struct {
		name    string
		timeout time.Duration
		wantErr string
	}{
		{name: "valid 30s", timeout: 30 * time.Second},
		{name: "valid 2m", timeout: 2 * time.Minute},
		{name: "zero is fine"},
		{name: "negative", timeout: -1 * time.Second, wantErr: "must be positive"},
		{name: "exceeds max", timeout: 10 * time.Minute, wantErr: "exceeds maximum"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := validProxy()
			svc.Targets[0].Options.RequestTimeout = tc.timeout

			err := svc.Validate()
			if tc.wantErr != "" {
				assert.ErrorContains(t, err, tc.wantErr)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestValidateTargetOptions_CustomHeaders exercises the custom header rules
// enforced by Validate: name syntax, CR/LF rejection, hop-by-hop, reserved
// and Host headers, the count and length limits, and canonical-name collisions.
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
	// validate runs Validate on a fresh valid service carrying the given headers.
	validate := func(headers map[string]string) error {
		svc := validProxy()
		svc.Targets[0].Options.CustomHeaders = headers
		return svc.Validate()
	}

	t.Run("valid headers", func(t *testing.T) {
		err := validate(map[string]string{
			"X-Custom": "value",
			"X-Trace":  "abc123",
		})
		assert.NoError(t, err)
	})

	t.Run("CRLF in key", func(t *testing.T) {
		err := validate(map[string]string{"X-Bad\r\nKey": "value"})
		assert.ErrorContains(t, err, "not a valid HTTP header name")
	})

	t.Run("CRLF in value", func(t *testing.T) {
		err := validate(map[string]string{"X-Good": "bad\nvalue"})
		assert.ErrorContains(t, err, "invalid characters")
	})

	t.Run("hop-by-hop header rejected", func(t *testing.T) {
		hopByHop := []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"}
		for _, h := range hopByHop {
			err := validate(map[string]string{h: "value"})
			assert.ErrorContains(t, err, "hop-by-hop", "header %q should be rejected", h)
		}
	})

	t.Run("reserved header rejected", func(t *testing.T) {
		reserved := []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"}
		for _, h := range reserved {
			err := validate(map[string]string{h: "value"})
			assert.ErrorContains(t, err, "managed by the proxy", "header %q should be rejected", h)
		}
	})

	t.Run("Host header rejected", func(t *testing.T) {
		err := validate(map[string]string{"Host": "evil.com"})
		assert.ErrorContains(t, err, "pass_host_header")
	})

	t.Run("too many headers", func(t *testing.T) {
		tooMany := make(map[string]string, 17)
		for i := 0; i < 17; i++ {
			tooMany[fmt.Sprintf("X-H%d", i)] = "v"
		}
		assert.ErrorContains(t, validate(tooMany), "exceeds maximum of 16")
	})

	t.Run("key too long", func(t *testing.T) {
		err := validate(map[string]string{strings.Repeat("X", 129): "v"})
		assert.ErrorContains(t, err, "key")
		assert.ErrorContains(t, err, "exceeds maximum length")
	})

	t.Run("value too long", func(t *testing.T) {
		err := validate(map[string]string{"X-Ok": strings.Repeat("v", 4097)})
		assert.ErrorContains(t, err, "value exceeds maximum length")
	})

	t.Run("duplicate canonical keys rejected", func(t *testing.T) {
		err := validate(map[string]string{
			"x-custom": "a",
			"X-Custom": "b",
		})
		assert.ErrorContains(t, err, "collide")
	})
}
func TestToProtoMapping_TargetOptions(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
Options: TargetOptions{
SkipTLSVerify: true,
RequestTimeout: 30 * time.Second,
PathRewrite: PathRewritePreserve,
CustomHeaders: map[string]string{"X-Custom": "val"},
},
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
opts := pm.Path[0].Options
require.NotNil(t, opts, "options should be populated")
assert.True(t, opts.SkipTlsVerify)
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
require.NotNil(t, opts.RequestTimeout)
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
}
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
rp := &Service{
ID: "svc-1",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string

View File

@@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
@@ -203,16 +203,6 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
})
}
// PKCEVerifierStore returns the server's PKCE verifier store, obtained through
// Create (presumably a memoizing factory helper — confirm against its definition).
//
// The store is built on a background context with a 10 minute maximum timeout,
// a 10 minute cleanup interval and a connection limit of 100. A construction
// failure terminates the process via log.Fatalf.
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
	build := func() *nbgrpc.PKCEVerifierStore {
		verifierStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
		if err != nil {
			log.Fatalf("failed to create PKCE verifier store: %v", err)
		}
		return verifierStore
	}
	return Create(s, build)
}
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())

View File

@@ -210,7 +210,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
return &m
})
}

View File

@@ -28,13 +28,9 @@ import (
"github.com/netbirdio/netbird/version"
)
const (
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
type Server interface {
Start(ctx context.Context) error
@@ -62,7 +58,6 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
disableLegacyManagementPort bool
autoResolveDomains bool
proxyAuthClose func()
@@ -86,7 +81,6 @@ type Config struct {
DisableMetrics bool
DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool
AutoResolveDomains bool
}
// NewServer initializes and configures a new Server instance
@@ -102,7 +96,6 @@ func NewServer(cfg *Config) *BaseServer {
mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
}
}
@@ -116,10 +109,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.cancel = cancel
s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
}
s.PeersManager()
s.GeoLocationManager()
@@ -248,6 +237,7 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
@@ -391,60 +381,6 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}()
}
// resolveDomains sets dnsDomain and mgmtSingleAccModeDomain from the store state.
//
// With no accounts in the store (fresh install) both fields get
// DefaultSelfHostedDomain. With existing accounts, the domain persisted for an
// arbitrary account is reused so addressing stays stable across config changes.
// Any store error, empty account ID or empty persisted domain is logged as a
// warning and falls back to DefaultSelfHostedDomain; the function never fails.
func (s *BaseServer) resolveDomains(ctx context.Context) {
	st := s.Store()

	// apply writes the same domain to both fields this function owns.
	apply := func(domain string) {
		s.dnsDomain = domain
		s.mgmtSingleAccModeDomain = domain
	}
	// fallback records why resolution failed, then applies the default domain.
	fallback := func(format string, args ...any) {
		log.WithContext(ctx).Warnf(format, args...)
		apply(DefaultSelfHostedDomain)
	}

	accountsCount, err := st.GetAccountsCounter(ctx)
	switch {
	case err != nil:
		fallback("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
		return
	case accountsCount == 0:
		apply(DefaultSelfHostedDomain)
		log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
		return
	}

	accountID, err := st.GetAnyAccountID(ctx)
	switch {
	case err != nil:
		fallback("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
		return
	case accountID == "":
		fallback("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
		return
	}

	domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
	switch {
	case err != nil:
		fallback("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
		return
	case domain == "":
		fallback("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
		return
	}

	apply(domain)
	log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {

View File

@@ -1,63 +0,0 @@
package server
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
)
// TestResolveDomains_FreshInstallUsesDefault checks that a store reporting zero
// accounts makes resolveDomains choose DefaultSelfHostedDomain for both fields.
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	t.Cleanup(mockCtrl.Finish)

	// Only the accounts counter is expected to be consulted on a fresh install.
	storeMock := store.NewMockStore(mockCtrl)
	storeMock.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)

	server := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](server, storeMock)

	server.resolveDomains(context.Background())

	require.Equal(t, DefaultSelfHostedDomain, server.dnsDomain)
	require.Equal(t, DefaultSelfHostedDomain, server.mgmtSingleAccModeDomain)
}
// TestResolveDomains_ExistingInstallUsesPersistedDomain checks that, when the
// store already holds an account, resolveDomains adopts that account's
// persisted domain for both fields.
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
	const persistedDomain = "vpn.mycompany.com"

	mockCtrl := gomock.NewController(t)
	t.Cleanup(mockCtrl.Finish)

	// One account exists; its ID is looked up and then its domain is read without locking.
	storeMock := store.NewMockStore(mockCtrl)
	storeMock.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
	storeMock.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
	storeMock.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return(persistedDomain, "", nil)

	server := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](server, storeMock)

	server.resolveDomains(context.Background())

	require.Equal(t, persistedDomain, server.dnsDomain)
	require.Equal(t, persistedDomain, server.mgmtSingleAccModeDomain)
}
// TestResolveDomains_StoreErrorFallsBackToDefault checks that an error from the
// accounts counter does not abort resolution: both fields end up set to
// DefaultSelfHostedDomain.
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	t.Cleanup(mockCtrl.Finish)

	// The very first store call fails, so no further store calls are expected.
	storeMock := store.NewMockStore(mockCtrl)
	storeMock.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))

	server := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](server, storeMock)

	server.resolveDomains(context.Background())

	require.Equal(t, DefaultSelfHostedDomain, server.dnsDomain)
	require.Equal(t, DefaultSelfHostedDomain, server.mgmtSingleAccModeDomain)
}

View File

@@ -1,61 +0,0 @@
package grpc
import (
"context"
"fmt"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type PKCEVerifierStore struct {
	// cache holds the verifier strings; Store keys them by the OAuth state parameter.
	cache *cache.Cache[string]
	// ctx is the context handed to NewPKCEVerifierStore; it is passed to every cache call.
	ctx context.Context
}
// NewPKCEVerifierStore builds a PKCEVerifierStore on top of the cache store
// returned by nbcache.NewStore, forwarding maxTimeout, cleanupInterval and
// maxConn to it unchanged.
//
// ctx is retained by the returned store and reused for all later cache calls.
// An error from nbcache.NewStore is returned wrapped.
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
	backend, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache store: %w", err)
	}

	verifiers := &PKCEVerifierStore{ctx: ctx}
	verifiers.cache = cache.New[string](backend)
	return verifiers, nil
}
// Store records verifier under the OAuth state parameter with an expiration of ttl.
// It returns a wrapped error when the cache write fails; on success the TTL
// (not the state or verifier) is logged at debug level.
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
	setErr := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl))
	if setErr != nil {
		return fmt.Errorf("failed to store PKCE verifier: %w", setErr)
	}

	log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
	return nil
}
// LoadAndDelete fetches the PKCE verifier stored for state and then removes it,
// so a verifier is meant to be usable only once.
//
// It returns the verifier and true when the lookup succeeds. Any lookup error
// is treated as "not found" and yields ("", false). A failed delete is only
// logged as a warning; the verifier is still returned.
//
// NOTE(review): the lookup and the delete are two separate cache calls, so two
// concurrent callers with the same state could presumably both obtain the
// verifier — confirm whether callers rely on strict single-use semantics.
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
	found, getErr := s.cache.Get(s.ctx, state)
	if getErr != nil {
		log.Debugf("PKCE verifier not found for state")
		return "", false
	}

	delErr := s.cache.Delete(s.ctx, state)
	if delErr != nil {
		log.Warnf("Failed to delete PKCE verifier for state: %v", delErr)
	}
	return found, true
}

View File

@@ -18,6 +18,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -82,12 +83,20 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
// TODO: use database to store these instead?
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
@@ -99,21 +108,42 @@ type proxyConnection struct {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}
// cleanupPKCEVerifiers sweeps pkceVerifiers once per pkceVerifierTTL and deletes
// every pkceEntry older than pkceVerifierTTL. It blocks until ctx is cancelled,
// so it is intended to run in its own goroutine.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
	sweepTicker := time.NewTicker(pkceVerifierTTL)
	defer sweepTicker.Stop()

	// dropExpired removes entries whose age, measured at sweep start, exceeds the TTL.
	// Values that are not a pkceEntry are left untouched.
	dropExpired := func() {
		sweepStart := time.Now()
		s.pkceVerifiers.Range(func(state, stored any) bool {
			entry, isEntry := stored.(pkceEntry)
			if isEntry && sweepStart.Sub(entry.createdAt) > pkceVerifierTTL {
				s.pkceVerifiers.Delete(state)
			}
			return true
		})
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-sweepTicker.C:
			dropExpired()
		}
	}
}
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
@@ -130,6 +160,11 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
}
}
// Close stops the background goroutines started by NewProxyServiceServer by
// cancelling the context they share.
//
// It is a no-op when no cancel function was installed, which is the case for a
// ProxyServiceServer built as a struct literal rather than through the
// constructor; calling the nil function would otherwise panic.
func (s *ProxyServiceServer) Close() {
	if s.pkceCleanupCancel == nil {
		return
	}
	s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}
@@ -142,7 +177,11 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
ctx := stream.Context()
peerInfo := PeerIPFromContext(ctx)
peerInfo := ""
if p, ok := peer.FromContext(ctx); ok {
peerInfo = p.Addr.String()
}
log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId()
@@ -756,10 +795,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier()
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{
@@ -796,10 +832,18 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
v, ok := s.pkceVerifiers.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")

View File

@@ -71,12 +71,14 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
return handler(ctx, req)
}
token, err := interceptor.validateProxyToken(ctx)
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
return nil, err
}
// token, err := interceptor.validateProxyToken(ctx)
// if err != nil {
// // Log auth failures explicitly; gRPC doesn't log these by default.
// log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
// return nil, err
// }
token := &types.ProxyAccessToken{ID: "dummy"}
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
return handler(ctx, req)
@@ -87,12 +89,13 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
return handler(srv, ss)
}
token, err := interceptor.validateProxyToken(ss.Context())
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
return err
}
// token, err := interceptor.validateProxyToken(ss.Context())
// if err != nil {
// // Log auth failures explicitly; gRPC doesn't log these by default.
// log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
// return err
// }
token := &types.ProxyAccessToken{ID: "dummy"} // TODO: Implement token validation for streaming methods.
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
wrapped := &wrappedServerStream{
@@ -107,7 +110,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := PeerIPFromContext(ctx)
clientIP := peerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")

View File

@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
l.cancel()
}
// PeerIPFromContext extracts the client IP from the gRPC context.
// peerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func PeerIPFromContext(ctx context.Context) string {
func peerIPFromContext(ctx context.Context) clientIP {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}

View File

@@ -5,10 +5,11 @@ import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"sync"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -93,16 +94,11 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
@@ -155,16 +151,11 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
@@ -194,16 +185,11 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
@@ -255,15 +241,10 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
}
func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
redirectURL := "https://app.example.com/callback"
@@ -284,43 +265,31 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
// Old format had only 2 parts: base64(url)|hmac
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err = s.ValidateState("base64url|hmac")
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
pkceVerifierStore: pkceStore,
}
// Store with tampered HMAC
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -41,10 +41,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -1379,10 +1379,9 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
if err != nil {
return "", "", err
}
userAuth.Domain = am.singleAccountModeDomain
userAuth.DomainCategory = types.PrivateCategory
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
@@ -1415,35 +1414,6 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
// updateUserAuthWithSingleMode rewrites the domain claims in userAuth for single
// account mode.
//
// DomainCategory is always set to private and Domain to singleAccountModeDomain
// first. If the store already holds an account, Domain is then replaced with
// that account's persisted domain. A NotFound status error or an empty account
// ID from the store keeps the singleAccountModeDomain value. Any other store
// error is returned unchanged (with the initial overrides already applied).
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
	const usingConfiguredDomainMsg = "using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode"

	userAuth.DomainCategory = types.PrivateCategory
	userAuth.Domain = am.singleAccountModeDomain

	accountID, err := am.Store.GetAnyAccountID(ctx)
	if err != nil {
		// Only a NotFound status error means "no account yet"; everything else is a real failure.
		statusErr, isStatusErr := status.FromError(err)
		if isStatusErr && statusErr.Type() == status.NotFound {
			log.WithContext(ctx).Debugf(usingConfiguredDomainMsg)
			return nil
		}
		return err
	}
	if accountID == "" {
		log.WithContext(ctx).Debugf(usingConfiguredDomainMsg)
		return nil
	}

	persistedDomain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		return err
	}

	userAuth.Domain = persistedDomain
	log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
	return nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager

View File

@@ -15,7 +15,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -3133,7 +3132,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
@@ -3967,116 +3966,3 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
}
}
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("real-domain.com", "private", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "real-domain.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.NotFound, "no accounts"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.Internal, "db down"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "db down")
// Defaults should still be set before error path
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("", "", status.Errorf(status.Internal, "query failed"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "query failed")
})
}

View File

@@ -220,13 +220,6 @@ const (
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
// DomainAdded indicates that a user added a custom domain
DomainAdded Activity = 116
// DomainDeleted indicates that a user deleted a custom domain
DomainDeleted Activity = 117
// DomainValidated indicates that a custom domain was validated
DomainValidated Activity = 118
AccountDeleted Activity = 99999
)
@@ -371,10 +364,6 @@ var activityMap = map[Activity]Code{
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},
}
// StringCode returns a string code of the activity

View File

@@ -193,9 +193,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)
oidcConfig := nbgrpc.ProxyOIDCConfig{
@@ -209,7 +206,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,

View File

@@ -98,17 +98,13 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)

View File

@@ -294,9 +294,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
localUsers++
} else {
idpUsers++
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
}
}
@@ -531,9 +531,6 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string {
if connectorID == "local" {
return "local"
}
idx := strings.LastIndex(connectorID, "-")
if idx <= 0 {
return "oidc"

View File

@@ -29,7 +29,6 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
return []*types.Account{
{
Id: "1",
@@ -207,13 +206,6 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {},
},
},
oidcUserID: {
Id: oidcUserID,
IsServiceUser: false,
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
},
Networks: []*networkTypes.Network{
{
@@ -286,14 +278,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"])
}
if properties["users"] != 3 {
t.Errorf("expected 3 users, got %d", properties["users"])
if properties["users"] != 2 {
t.Errorf("expected 1 users, got %d", properties["users"])
}
if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
}
if properties["pats"] != 5 {
t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
if properties["pats"] != 4 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
}
if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
@@ -377,20 +369,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["local_users_count"] != 1 {
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
}
if properties["idp_users_count"] != 2 {
t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_local"] != 1 {
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
if properties["idp_users_count"] != 1 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
}
if properties["embedded_idp_users_oidc"] != 1 {
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
}
if properties["embedded_idp_count"] != 3 {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
if properties["embedded_idp_count"] != 1 {
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
if properties["services"] != 2 {
@@ -450,8 +436,7 @@ func TestExtractIdpType(t *testing.T) {
{"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "local"},
{"d6jvvp69kmnc73c9pl40", "oidc"},
{"local", "oidc"},
{"", "oidc"},
}

View File

@@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
var service *rpservice.Service
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
@@ -5040,99 +5040,6 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
return serviceList, nil
}
// RenewEphemeralService stamps meta_last_renewed_at with the current time on the
// ephemeral service matching the given account, source peer and domain.
// It returns a NotFound status when no row matched and an Internal status when
// the update itself failed.
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
	const ownedEphemeralDomain = "account_id = ? AND source_peer = ? AND domain = ? AND source = ?"

	renewal := s.db.
		Model(&rpservice.Service{}).
		Where(ownedEphemeralDomain, accountID, peerID, domain, rpservice.SourceEphemeral).
		Update("meta_last_renewed_at", time.Now())

	switch {
	case renewal.Error != nil:
		log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", renewal.Error)
		return status.Errorf(status.Internal, "renew ephemeral service")
	case renewal.RowsAffected == 0:
		// Nothing was touched, so there is no such ephemeral service to renew.
		return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
	}
	return nil
}
// GetExpiredEphemeralServices lists ephemeral services whose last renewal is older than ttl.
// Only id, account_id, source_peer and domain are selected, since that is all reaping needs.
// limit caps how many rows a single call loads, and rows with an empty source_peer are
// left out to skip malformed legacy data.
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
	var expired []*rpservice.Service

	// Anything last renewed before this instant has outlived its TTL.
	deadline := time.Now().Add(-ttl)

	err := s.db.
		Select("id", "account_id", "source_peer", "domain").
		Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, deadline).
		Limit(limit).
		Find(&expired).
		Error
	if err != nil {
		log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", err)
		return nil, status.Errorf(status.Internal, "get expired ephemeral services")
	}
	return expired, nil
}
// CountEphemeralServicesByPeer reports how many ephemeral services the given peer owns
// within the account.
// With LockingStrengthNone a plain COUNT query is issued. With any other strength the
// matching row ids are selected under that lock and counted in memory, because Postgres
// disallows FOR UPDATE on COUNT(*). Use LockingStrengthUpdate inside a transaction to
// serialize concurrent create operations.
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
	const ownedByPeer = "account_id = ? AND source_peer = ? AND source = ?"

	var (
		total int64
		err   error
	)
	if lockStrength != LockingStrengthNone {
		// Row-level lock: fetch the ids themselves and count them afterwards.
		var lockedIDs []string
		err = s.db.Model(&rpservice.Service{}).
			Clauses(clause.Locking{Strength: string(lockStrength)}).
			Select("id").
			Where(ownedByPeer, accountID, peerID, rpservice.SourceEphemeral).
			Pluck("id", &lockedIDs).
			Error
		total = int64(len(lockedIDs))
	} else {
		err = s.db.Model(&rpservice.Service{}).
			Where(ownedByPeer, accountID, peerID, rpservice.SourceEphemeral).
			Count(&total).
			Error
	}

	if err != nil {
		log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", err)
		return 0, status.Errorf(status.Internal, "count ephemeral services")
	}
	return total, nil
}
// EphemeralServiceExists reports whether the peer already has an ephemeral service for domain.
// With LockingStrengthNone it counts matching rows. With any other strength it selects a
// single matching id under that lock and reports whether one came back. Use
// LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
	const ownedEphemeralDomain = "account_id = ? AND source_peer = ? AND domain = ? AND source = ?"

	var (
		found bool
		err   error
	)
	if lockStrength == LockingStrengthNone {
		var matches int64
		err = s.db.Model(&rpservice.Service{}).
			Where(ownedEphemeralDomain, accountID, peerID, domain, rpservice.SourceEphemeral).
			Count(&matches).
			Error
		found = matches > 0
	} else {
		// Locked variant: one id is enough to prove existence.
		var lockedID string
		err = s.db.Model(&rpservice.Service{}).
			Clauses(clause.Locking{Strength: string(lockStrength)}).
			Select("id").
			Where(ownedEphemeralDomain, accountID, peerID, domain, rpservice.SourceEphemeral).
			Limit(1).
			Pluck("id", &lockedID).
			Error
		found = lockedID != ""
	}

	if err != nil {
		log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", err)
		return false, status.Errorf(status.Internal, "check ephemeral service existence")
	}
	return found, nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
tx := s.db

View File

@@ -257,15 +257,10 @@ type Store interface {
UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)

View File

@@ -208,21 +208,6 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
}
// CountEphemeralServicesByPeer mocks base method.
// The call and its arguments are handed to the gomock controller, and the two
// values the controller returns are passed back to the caller.
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
// Comma-ok assertions: a slot that does not hold the expected type yields the zero value instead of panicking.
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
// The arguments are forwarded to the controller's RecordCallWithMethodType and the
// resulting *gomock.Call is returned.
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
}
// CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper()
@@ -701,21 +686,6 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// EphemeralServiceExists mocks base method.
// The call and its arguments are handed to the gomock controller, and the two
// values the controller returns are passed back to the caller.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
// Comma-ok assertions: a slot that does not hold the expected type yields the zero value instead of panicking.
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
// The arguments are forwarded to the controller's RecordCallWithMethodType and the
// resulting *gomock.Call is returned.
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
}
// ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper()
@@ -1392,21 +1362,6 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetExpiredEphemeralServices mocks base method.
// The call and its arguments are handed to the gomock controller, and the two
// values the controller returns are passed back to the caller.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
// Comma-ok assertions: a slot that does not hold the expected type yields the zero value instead of panicking.
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
// The arguments are forwarded to the controller's RecordCallWithMethodType and the
// resulting *gomock.Call is returned.
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
}
// GetGroupByID mocks base method.
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
m.ctrl.T.Helper()
@@ -1932,18 +1887,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain)
}
// GetServiceByID mocks base method.
@@ -2446,20 +2401,6 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
}
// RenewEphemeralService mocks base method.
// The call and its arguments are handed to the gomock controller, and the single
// value the controller returns is passed back to the caller.
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
// Comma-ok assertion: a slot that does not hold an error yields nil instead of panicking.
ret0, _ := ret[0].(error)
return ret0
}
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
// The arguments are forwarded to the controller's RecordCallWithMethodType and the
// resulting *gomock.Call is returned.
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
}
// RevokeProxyAccessToken mocks base method.
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
m.ctrl.T.Helper()

View File

@@ -1,185 +0,0 @@
package telemetry
import (
"context"
"math"
"sync"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
// without publishing individual account labels
type AccountDurationAggregator struct {
mu sync.RWMutex
accounts map[string]*accountHistogram
meterProvider *sdkmetric.MeterProvider
manualReader *sdkmetric.ManualReader
FlushInterval time.Duration
MaxAge time.Duration
ctx context.Context
}
// accountHistogram is the per-account bookkeeping entry held by AccountDurationAggregator.
type accountHistogram struct {
	// lastUpdate is the wall-clock time of the most recent Record for the account;
	// it is compared against MaxAge to decide eviction.
	lastUpdate time.Time
	// histogram is the instrument durations for the account are recorded into.
	histogram metric.Int64Histogram
}
// NewAccountDurationAggregator builds an aggregator backed by its own OTel meter provider
// and a manual reader configured for delta temporality.
// flushInterval and maxAge are stored on the returned value as FlushInterval and MaxAge;
// ctx is kept and reused for later record, collect and shutdown calls.
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
	// Every instrument kind gets delta temporality, so each collection only
	// covers what was recorded since the previous one.
	deltaOnly := func(sdkmetric.InstrumentKind) metricdata.Temporality {
		return metricdata.DeltaTemporality
	}
	reader := sdkmetric.NewManualReader(sdkmetric.WithTemporalitySelector(deltaOnly))
	provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))

	agg := &AccountDurationAggregator{
		ctx:           ctx,
		FlushInterval: flushInterval,
		MaxAge:        maxAge,
		manualReader:  reader,
		meterProvider: provider,
		accounts:      make(map[string]*accountHistogram),
	}
	return agg
}
// Record stores duration, in milliseconds, for accountID and refreshes the account's
// last-update time. The first recording for an account creates its map entry; if the
// histogram cannot be created the sample is dropped without reporting an error.
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
	a.mu.Lock()
	defer a.mu.Unlock()

	entry, known := a.accounts[accountID]
	if !known {
		hist, err := a.meterProvider.Meter("account-aggregator").Int64Histogram(
			"sync_duration_per_account",
			metric.WithUnit("milliseconds"),
		)
		if err != nil {
			return
		}
		entry = &accountHistogram{histogram: hist}
		a.accounts[accountID] = entry
	}

	// The account id travels as an attribute so data points can be told apart at flush time.
	byAccount := metric.WithAttributes(attribute.String("account_id", accountID))
	entry.histogram.Record(a.ctx, duration.Milliseconds(), byAccount)
	entry.lastUpdate = time.Now()
}
// FlushAndGetP95s collects from the manual reader and returns one P95 value per int64
// histogram data point that yields a positive P95. Afterwards it evicts accounts that
// have not been updated within MaxAge. It returns nil when collection fails, in which
// case no eviction happens.
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
	a.mu.Lock()
	defer a.mu.Unlock()

	var collected metricdata.ResourceMetrics
	if err := a.manualReader.Collect(a.ctx, &collected); err != nil {
		return nil
	}

	flushedAt := time.Now()
	result := make([]int64, 0, len(a.accounts))

	for _, scope := range collected.ScopeMetrics {
		for _, m := range scope.Metrics {
			hist, isInt64Histogram := m.Data.(metricdata.Histogram[int64])
			if !isInt64Histogram {
				continue
			}
			for _, point := range hist.DataPoints {
				a.processDataPoint(point, flushedAt, &result)
			}
		}
	}

	a.cleanupStaleAccounts(flushedAt)
	return result
}
// processDataPoint appends the data point's P95 to p95s. Points without an account_id
// attribute, and points whose computed P95 is not positive, are ignored. now is accepted
// but not used here.
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
	if extractAccountID(dataPoint) == "" {
		return
	}

	p95 := calculateP95FromHistogram(dataPoint)
	if p95 <= 0 {
		return
	}
	*p95s = append(*p95s, p95)
}
// cleanupStaleAccounts drops every account whose last update is more than MaxAge
// before now. Callers are expected to hold a.mu.
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
	for id, entry := range a.accounts {
		if now.Sub(entry.lastUpdate) <= a.MaxAge {
			continue
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(a.accounts, id)
	}
}
// extractAccountID returns the string value of the data point's "account_id"
// attribute, or "" when the data point carries no such attribute.
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
	attrs := dp.Attributes.ToSlice()
	for i := range attrs {
		if attrs[i].Key != "account_id" {
			continue
		}
		return attrs[i].Value.AsString()
	}
	return ""
}
// isStaleAccount reports whether accountID is tracked and its last update lies more
// than MaxAge before now. Unknown accounts are never considered stale.
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
	if entry, tracked := a.accounts[accountID]; tracked {
		return now.Sub(entry.lastUpdate) > a.MaxAge
	}
	return false
}
// calculateP95FromHistogram estimates the 95th percentile of a histogram data point.
// It walks the buckets until the cumulative count reaches ceil(0.95 * Count) and returns
// the bound stored at that bucket's index. When that bucket has no bound of its own, it
// returns the recorded maximum if defined, otherwise the integer mean (Sum / Count).
// An empty data point yields 0.
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
	if dp.Count == 0 {
		return 0
	}

	// Number of samples that must be covered before the percentile is reached.
	rank := uint64(math.Ceil(0.95 * float64(dp.Count)))
	if rank < 1 {
		rank = 1
	}

	mean := dp.Sum / int64(dp.Count)

	var covered uint64
	for idx := range dp.BucketCounts {
		covered += dp.BucketCounts[idx]
		if covered < rank {
			continue
		}
		if idx < len(dp.Bounds) {
			return int64(dp.Bounds[idx])
		}
		// No bound at this index: prefer the observed maximum when it was recorded.
		if observedMax, ok := dp.Max.Value(); ok {
			return observedMax
		}
		return mean
	}
	return mean
}
// Shutdown shuts down the aggregator's meter provider using the context given at
// construction time and returns the provider's error, if any.
func (a *AccountDurationAggregator) Shutdown() error {
	provider := a.meterProvider
	return provider.Shutdown(a.ctx)
}

View File

@@ -1,219 +0,0 @@
package telemetry
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDeltaTemporality_P95ReflectsCurrentWindow checks that, with delta temporality,
// each flush window only reflects recordings made since the previous flush rather
// than all-time data.
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// recordBatch feeds 100 identical samples for account-A.
	recordBatch := func(latency time.Duration) {
		for i := 0; i < 100; i++ {
			agg.Record("account-A", latency)
		}
	}

	// Window 1: slow requests (500ms each).
	recordBatch(500 * time.Millisecond)
	slowWindow := agg.FlushAndGetP95s()
	require.Len(t, slowWindow, 1, "should have P95 for one account")
	slowP95 := slowWindow[0]
	assert.GreaterOrEqual(t, slowP95, int64(200),
		"first window P95 should reflect the 500ms recordings")

	// Window 2: fast requests (10ms each).
	recordBatch(10 * time.Millisecond)
	fastWindow := agg.FlushAndGetP95s()
	require.Len(t, fastWindow, 1, "should have P95 for one account")
	fastP95 := fastWindow[0]

	// The slow recordings from window 1 must no longer weigh in.
	assert.Less(t, fastP95, slowP95,
		"second window P95 should be lower than first — delta temporality "+
			"ensures each window only reflects recent recordings")
}
// TestEqualWeightPerAccount checks that each account contributes exactly one P95
// value, regardless of how many requests it made.
func TestEqualWeightPerAccount(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// One noisy customer followed by three normal ones.
	workload := []struct {
		account  string
		requests int
		latency  time.Duration
	}{
		{"account-A", 10000, 500 * time.Millisecond},
		{"account-B", 10, 50 * time.Millisecond},
		{"account-C", 10, 50 * time.Millisecond},
		{"account-D", 10, 50 * time.Millisecond},
	}
	for _, w := range workload {
		for i := 0; i < w.requests; i++ {
			agg.Record(w.account, w.latency)
		}
	}

	// Exactly four values are expected: one per account.
	perAccount := agg.FlushAndGetP95s()
	assert.Len(t, perAccount, 4, "each account should contribute exactly one P95")
}
// TestStaleAccountEviction checks that an account with no recordings for longer than
// MaxAge is removed from the aggregator's map on flush, while an account that keeps
// recording stays.
func TestStaleAccountEviction(t *testing.T) {
	ctx := context.Background()
	// Use a very short MaxAge so we can test staleness
	agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
	defer func(agg *AccountDurationAggregator) {
		err := agg.Shutdown()
		if err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}(agg)

	agg.Record("account-A", 100*time.Millisecond)
	agg.Record("account-B", 200*time.Millisecond)

	// Both accounts should appear
	p95s := agg.FlushAndGetP95s()
	assert.Len(t, p95s, 2, "both accounts should have P95 values")

	// Wait for account-A to become stale, then only update account-B
	time.Sleep(60 * time.Millisecond)
	agg.Record("account-B", 200*time.Millisecond)

	// Only account-B recorded in this window, so exactly one P95 is expected.
	p95s = agg.FlushAndGetP95s()
	assert.Len(t, p95s, 1, "only the recently updated account-B should have a P95 value")

	// account-A should have been evicted from the accounts map
	agg.mu.RLock()
	_, accountAExists := agg.accounts["account-A"]
	_, accountBExists := agg.accounts["account-B"]
	agg.mu.RUnlock()

	assert.False(t, accountAExists, "stale account-A should be evicted from map")
	assert.True(t, accountBExists, "active account-B should remain in map")
}
// TestStaleAccountEviction_DoesNotReappear checks that, with delta temporality, an
// evicted stale account does not show up again in later flushes.
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 50*time.Millisecond)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// tracked reports whether the account is still present in the aggregator's map.
	tracked := func(accountID string) bool {
		agg.mu.RLock()
		defer agg.mu.RUnlock()
		_, present := agg.accounts[accountID]
		return present
	}

	agg.Record("account-stale", 100*time.Millisecond)

	// Let the account exceed MaxAge before the first flush.
	time.Sleep(60 * time.Millisecond)

	// First flush: staleness is detected and the account is evicted.
	_ = agg.FlushAndGetP95s()
	assert.False(t, tracked("account-stale"), "account should be evicted after first flush")

	// Second flush: nothing was recorded in between, so nothing may come back.
	secondFlush := agg.FlushAndGetP95s()
	assert.Empty(t, secondFlush,
		"evicted account should not reappear in subsequent flushes with delta temporality")
}
// TestP95Calculation_SingleSample checks that a single recording produces exactly
// one positive P95 value.
func TestP95Calculation_SingleSample(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	agg.Record("account-A", 150*time.Millisecond)

	flushed := agg.FlushAndGetP95s()
	require.Len(t, flushed, 1)

	// With one sample the P95 is the bucket bound containing 150ms.
	only := flushed[0]
	assert.Greater(t, only, int64(0), "P95 of a single sample should be positive")
}
// TestP95Calculation_AllSameValue checks that 100 identical 100ms recordings produce
// exactly one positive P95 value.
func TestP95Calculation_AllSameValue(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// Every sample is 100ms, so the P95 is the bucket bound containing 100ms.
	const samples = 100
	for i := 0; i < samples; i++ {
		agg.Record("account-A", 100*time.Millisecond)
	}

	flushed := agg.FlushAndGetP95s()
	require.Len(t, flushed, 1)
	assert.Greater(t, flushed[0], int64(0))
}
// TestMultipleAccounts_IndependentP95s checks that a fast and a slow account each get
// their own P95 and that the two values land on opposite sides of 1000ms.
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// All-fast account (10ms) first, then all-slow account (5000ms).
	for i := 0; i < 100; i++ {
		agg.Record("account-fast", 10*time.Millisecond)
	}
	for i := 0; i < 100; i++ {
		agg.Record("account-slow", 5000*time.Millisecond)
	}

	flushed := agg.FlushAndGetP95s()
	require.Len(t, flushed, 2, "should have two P95 values")

	// The order of the two values is not fixed, so sort them into low/high.
	lowP95, highP95 := flushed[0], flushed[1]
	if highP95 < lowP95 {
		lowP95, highP95 = highP95, lowP95
	}

	assert.Less(t, lowP95, int64(1000),
		"fast account P95 should be well under 1000ms")
	assert.Greater(t, highP95, int64(1000),
		"slow account P95 should be well over 1000ms")
}

View File

@@ -13,24 +13,18 @@ const HighLatencyThreshold = time.Second * 7
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
syncRequestDurationP95ByAccount metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
loginRequestDurationP95ByAccount metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
// Per-account aggregation
syncDurationAggregator *AccountDurationAggregator
loginDurationAggregator *AccountDurationAggregator
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
@@ -99,14 +93,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -115,14 +101,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
// We use a histogram here because multiple channels exist at the same time and we want to see a slice at any given time.
// Then we should be able to extract min, max, mean and the percentiles.
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
@@ -135,32 +113,20 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
grpcMetrics := &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
loginRequestDuration: loginRequestDuration,
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
channelQueueLength: channelQueue,
ctx: ctx,
syncDurationAggregator: syncDurationAggregator,
loginDurationAggregator: loginDurationAggregator,
}
go grpcMetrics.startSyncP95Flusher()
go grpcMetrics.startLoginP95Flusher()
return grpcMetrics, err
return &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
loginRequestDuration: loginRequestDuration,
channelQueueLength: channelQueue,
ctx: ctx,
}, err
}
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
@@ -191,9 +157,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
// CountLoginRequestDuration counts the duration of the login gRPC requests
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
if duration > HighLatencyThreshold {
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
@@ -202,44 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration records how long a sync gRPC request took. The
// duration is written to the sync duration histogram in milliseconds and is
// also handed, together with the account ID, to the sync duration aggregator.
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
	elapsedMs := duration.Milliseconds()
	grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, elapsedMs)

	grpcMetrics.syncDurationAggregator.Record(accountID, duration)
}
// startSyncP95Flusher blocks until the metrics context is done. Each time the
// sync aggregator's flush interval elapses, it flushes the aggregator and
// records every P95 value it returns in the per-account sync P95 histogram.
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
	flushTick := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
	defer flushTick.Stop()

	for {
		select {
		case <-flushTick.C:
			flushed := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
			for _, value := range flushed {
				grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, value)
			}
		case <-grpcMetrics.ctx.Done():
			return
		}
	}
}
// startLoginP95Flusher blocks until the metrics context is done. Each time the
// login aggregator's flush interval elapses, it flushes the aggregator and
// records every P95 value it returns in the per-account login P95 histogram.
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
	flushTick := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
	defer flushTick.Stop()

	for {
		select {
		case <-flushTick.C:
			flushed := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
			for _, value := range flushed {
				grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, value)
			}
		case <-grpcMetrics.ctx.Done():
			return
		}
	}
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -3,7 +3,6 @@ package accesslog
import (
"context"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -14,23 +13,6 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
const (
requestThreshold = 10000 // Log every 10k requests
bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
)
// domainUsage holds the running request and byte counters kept for a single
// domain. Each counter has its own window: when a counter reaches its logging
// threshold it is reset to zero and its start time is moved to the current time.
type domainUsage struct {
requestCount int64 // requests counted since requestStartTime
requestStartTime time.Time // start of the current request-counting window
bytesTransferred int64 // bytes counted since bytesStartTime
bytesStartTime time.Time // start of the current byte-counting window
lastActivity time.Time // Track last activity for cleanup
}
type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
@@ -40,11 +22,6 @@ type Logger struct {
client gRPCClient
logger *log.Logger
trustedProxies []netip.Prefix
usageMux sync.Mutex
domainUsage map[string]*domainUsage
cleanupCancel context.CancelFunc
}
// NewLogger creates a new access log Logger. The trustedProxies parameter
@@ -54,26 +31,10 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
if logger == nil {
logger = log.StandardLogger()
}
ctx, cancel := context.WithCancel(context.Background())
l := &Logger{
return &Logger{
client: client,
logger: logger,
trustedProxies: trustedProxies,
domainUsage: make(map[string]*domainUsage),
cleanupCancel: cancel,
}
// Start background cleanup routine
go l.cleanupStaleUsage(ctx)
return l
}
// Close cancels the context of the background cleanup routine, if a cancel
// function was set. It should be called during graceful shutdown.
func (l *Logger) Close() {
	cancel := l.cleanupCancel
	if cancel == nil {
		return
	}
	cancel()
}
@@ -90,8 +51,6 @@ type logEntry struct {
AuthMechanism string
UserId string
AuthSuccess bool
BytesUpload int64
BytesDownload int64
}
func (l *Logger) log(ctx context.Context, entry logEntry) {
@@ -125,8 +84,6 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId,
AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
},
}); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log.
@@ -146,82 +103,3 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
}
}()
}
// trackUsage adds one request and bytesTransferred bytes to the counters kept
// for domain, creating the entry on first use. Calls with an empty domain are
// ignored. Whenever the request counter reaches requestThreshold, or the byte
// counter reaches bytesThreshold, a summary line is logged and that counter
// and its window start are reset. The usage mutex is held for the whole call.
func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
	if domain == "" {
		return
	}

	l.usageMux.Lock()
	defer l.usageMux.Unlock()

	stamp := time.Now()

	entry, known := l.domainUsage[domain]
	if !known {
		entry = &domainUsage{
			requestStartTime: stamp,
			bytesStartTime:   stamp,
			lastActivity:     stamp,
		}
		l.domainUsage[domain] = entry
	}
	entry.lastActivity = stamp

	// Request window: count this request and report once the threshold is hit.
	entry.requestCount++
	if entry.requestCount >= requestThreshold {
		window := time.Since(entry.requestStartTime)
		fields := log.Fields{
			"domain":   domain,
			"requests": entry.requestCount,
			"duration": window.String(),
		}
		l.logger.WithFields(fields).Infof("domain %s had %d requests over %s", domain, entry.requestCount, window)

		entry.requestCount = 0
		entry.requestStartTime = stamp
	}

	// Byte window: add the transferred bytes and report once the threshold is hit.
	entry.bytesTransferred += bytesTransferred
	if entry.bytesTransferred >= bytesThreshold {
		window := time.Since(entry.bytesStartTime)
		gigabytes := float64(entry.bytesTransferred) / (1024 * 1024 * 1024)
		fields := log.Fields{
			"domain":   domain,
			"bytes":    entry.bytesTransferred,
			"bytes_gb": gigabytes,
			"duration": window.String(),
		}
		l.logger.WithFields(fields).Infof("domain %s transferred %.2f GB over %s", domain, gigabytes, window)

		entry.bytesTransferred = 0
		entry.bytesStartTime = stamp
	}
}
// cleanupStaleUsage runs until ctx is done. Every usageCleanupPeriod it takes
// the usage mutex and deletes each domain entry whose last activity is more
// than usageInactiveWindow in the past, then logs how many entries were
// deleted when that number is non-zero.
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
	sweepTick := time.NewTicker(usageCleanupPeriod)
	defer sweepTick.Stop()

	for {
		select {
		case <-sweepTick.C:
			l.usageMux.Lock()
			sweepTime := time.Now()
			evicted := 0
			for name, entry := range l.domainUsage {
				if sweepTime.Sub(entry.lastActivity) <= usageInactiveWindow {
					continue
				}
				delete(l.domainUsage, name)
				evicted++
			}
			l.usageMux.Unlock()

			// Log outside the lock so logging never delays request tracking.
			if evicted > 0 {
				l.logger.Debugf("cleaned up %d stale domain usage entries", evicted)
			}
		case <-ctx.Done():
			return
		}
	}
}

View File

@@ -32,14 +32,6 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
status: http.StatusOK,
}
var bytesRead int64
if r.Body != nil {
r.Body = &bodyCounter{
ReadCloser: r.Body,
bytesRead: &bytesRead,
}
}
// Resolve the source IP using trusted proxy configuration before passing
// the request on, as the proxy will modify forwarding headers.
sourceIp := extractSourceIP(r, l.trustedProxies)
@@ -61,9 +53,6 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
host = r.Host
}
bytesUpload := bytesRead
bytesDownload := sw.bytesWritten
entry := logEntry{
ID: requestID,
ServiceId: capturedData.GetServiceId(),
@@ -77,15 +66,10 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload,
BytesDownload: bytesDownload,
}
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
l.log(r.Context(), entry)
// Track usage for cost monitoring (upload + download) by domain
l.trackUsage(host, bytesUpload+bytesDownload)
})
}

View File

@@ -1,39 +1,18 @@
package accesslog
import (
"io"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
)
// statusWriter captures the HTTP status code and bytes written from responses.
// statusWriter captures the HTTP status code from WriteHeader calls.
// It embeds responsewriter.PassthroughWriter which handles all the optional
// interfaces (Hijacker, Flusher, Pusher) automatically.
type statusWriter struct {
*responsewriter.PassthroughWriter
status int
bytesWritten int64
status int
}
func (w *statusWriter) WriteHeader(status int) {
w.status = status
w.PassthroughWriter.WriteHeader(status)
}
// Write forwards b to the embedded PassthroughWriter and adds the number of
// bytes it reports as written to the running bytesWritten total. The count
// and error from the embedded writer are returned unchanged.
func (w *statusWriter) Write(b []byte) (int, error) {
	written, writeErr := w.PassthroughWriter.Write(b)
	w.bytesWritten = w.bytesWritten + int64(written)
	return written, writeErr
}
// bodyCounter decorates an io.ReadCloser and keeps a running total of the
// bytes that have been read through it. The total lives behind a pointer so
// the code that created the counter can read it after the body was consumed.
type bodyCounter struct {
	io.ReadCloser
	bytesRead *int64
}

// Read delegates to the wrapped reader and adds the number of bytes it
// returned to the shared total, also when the read reports an error.
func (bc *bodyCounter) Read(p []byte) (int, error) {
	got, readErr := bc.ReadCloser.Read(p)
	*bc.bytesRead = *bc.bytesRead + int64(got)
	return got, readErr
}

View File

@@ -42,10 +42,6 @@ type domainInfo struct {
err string
}
type metricsRecorder interface {
RecordCertificateIssuance(duration time.Duration)
}
// Manager wraps autocert.Manager with domain tracking and cross-replica
// coordination via a pluggable locking strategy. The locker prevents
// duplicate ACME requests when multiple replicas share a certificate cache.
@@ -59,7 +55,6 @@ type Manager struct {
certNotifier certificateNotifier
logger *log.Logger
metrics metricsRecorder
}
// NewManager creates a new ACME certificate manager. The certDir is used
@@ -68,7 +63,7 @@ type Manager struct {
// eabKID and eabHMACKey are optional External Account Binding credentials
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64
// URL-encoded string provided by the CA.
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager {
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
if logger == nil {
logger = log.StandardLogger()
}
@@ -78,7 +73,6 @@ func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificat
domains: make(map[domain.Domain]*domainInfo),
certNotifier: notifier,
logger: logger,
metrics: metrics,
}
var eab *acme.ExternalAccountBinding
@@ -135,45 +129,24 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
// prefetchCertificate proactively triggers certificate generation for a domain.
// It acquires a distributed lock to prevent multiple replicas from issuing
// duplicate ACME requests. If the certificate appears in the cache while waiting
// for the lock, it cancels the wait and uses the cached certificate.
// duplicate ACME requests. The second replica will block until the first
// finishes, then find the certificate in the cache.
func (mgr *Manager) prefetchCertificate(d domain.Domain) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
name := d.PunycodeString()
if mgr.certExistsInCache(ctx, name) {
mgr.logger.Infof("certificate for domain %q already exists in cache before lock attempt", name)
mgr.loadAndFinalizeCachedCert(ctx, d, name)
return
}
go mgr.pollCacheAndCancel(ctx, name, cancel)
// Acquire lock
mgr.logger.Infof("acquiring cert lock for domain %q", name)
lockStart := time.Now()
unlock, err := mgr.locker.Lock(ctx, name)
if err != nil {
if mgr.certExistsInCache(context.Background(), name) {
mgr.logger.Infof("certificate for domain %q appeared in cache while waiting for lock", name)
mgr.loadAndFinalizeCachedCert(context.Background(), d, name)
return
}
mgr.logger.Warnf("acquire cert lock for domain %q: %v", name, err)
// Continue without lock
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", name, err)
} else {
mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart))
defer unlock()
}
if mgr.certExistsInCache(ctx, name) {
mgr.logger.Infof("certificate for domain %q already exists in cache after lock acquisition, skipping ACME request", name)
mgr.loadAndFinalizeCachedCert(ctx, d, name)
return
}
hello := &tls.ClientHelloInfo{
ServerName: name,
Conn: &dummyConn{ctx: ctx},
@@ -188,10 +161,6 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
return
}
if mgr.metrics != nil {
mgr.metrics.RecordCertificateIssuance(elapsed)
}
mgr.setDomainState(d, domainReady, "")
now := time.Now()
@@ -230,78 +199,6 @@ func (mgr *Manager) setDomainState(d domain.Domain, state domainState, errMsg st
}
}
// certExistsInCache reports whether the manager's certificate cache returns
// an entry for domain without error. It reports false when no cache is set.
func (mgr *Manager) certExistsInCache(ctx context.Context, domain string) bool {
	cache := mgr.Cache
	if cache == nil {
		return false
	}
	if _, err := cache.Get(ctx, domain); err != nil {
		return false
	}
	return true
}
// pollCacheAndCancel checks every five seconds whether a certificate for
// domain is present in the cache. As soon as one is found it calls cancel and
// returns; it also returns when ctx is done. The cache lookup itself uses a
// background context rather than ctx.
func (mgr *Manager) pollCacheAndCancel(ctx context.Context, domain string, cancel context.CancelFunc) {
	pollTick := time.NewTicker(5 * time.Second)
	defer pollTick.Stop()

	for {
		select {
		case <-pollTick.C:
			if !mgr.certExistsInCache(context.Background(), domain) {
				continue
			}
			mgr.logger.Debugf("cert detected in cache for domain %q, cancelling lock wait", domain)
			cancel()
			return
		case <-ctx.Done():
			return
		}
	}
}
// loadAndFinalizeCachedCert requests the certificate for name through
// GetCertificate and updates the state of domain d accordingly. On failure the
// domain is marked failed with the error text. On success the domain is marked
// ready, the certificate details are logged when a parsed leaf is available,
// and the certificate notifier (if any) is told about the domain's certificate.
func (mgr *Manager) loadAndFinalizeCachedCert(ctx context.Context, d domain.Domain, name string) {
	cert, err := mgr.GetCertificate(&tls.ClientHelloInfo{
		ServerName: name,
		Conn:       &dummyConn{ctx: ctx},
	})
	if err != nil {
		mgr.logger.Warnf("load cached certificate for domain %q: %v", name, err)
		mgr.setDomainState(d, domainFailed, err.Error())
		return
	}

	mgr.setDomainState(d, domainReady, "")

	loadedAt := time.Now()
	switch {
	case cert == nil || cert.Leaf == nil:
		// Without a parsed leaf there is nothing more detailed to report.
		mgr.logger.Infof("loaded cached certificate for domain %q", name)
	default:
		leaf := cert.Leaf
		mgr.logger.Infof("loaded cached certificate for domain %q: serial=%s SANs=%v notBefore=%s, notAfter=%s, now=%s",
			name,
			leaf.SerialNumber.Text(16),
			leaf.DNSNames,
			leaf.NotBefore.UTC().Format(time.RFC3339),
			leaf.NotAfter.UTC().Format(time.RFC3339),
			loadedAt.UTC().Format(time.RFC3339),
		)
		mgr.logCertificateDetails(name, leaf, loadedAt)
	}

	// Only the map lookup happens under the read lock; the notifier is called
	// after the lock has been released.
	mgr.mu.RLock()
	tracked := mgr.domains[d]
	mgr.mu.RUnlock()

	if tracked == nil || mgr.certNotifier == nil {
		return
	}
	if notifyErr := mgr.certNotifier.NotifyCertificateIssued(ctx, tracked.accountID, tracked.serviceID, name); notifyErr != nil {
		mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, notifyErr)
	}
}
// logCertificateDetails logs certificate validity and SCT timestamps.
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
if cert.NotBefore.After(now) {

View File

@@ -10,7 +10,7 @@ import (
)
func TestHostPolicy(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
mgr.AddDomain("example.com", "acc1", "rp1")
// Wait for the background prefetch goroutine to finish so the temp dir
@@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) {
}
func TestDomainStates(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")

View File

@@ -1,106 +1,64 @@
package metrics
import (
"context"
"net/http"
"sync"
"strconv"
"time"
"go.opentelemetry.io/otel/metric"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
)
type Metrics struct {
ctx context.Context
requestsTotal metric.Int64Counter
activeRequests metric.Int64UpDownCounter
configuredDomains metric.Int64UpDownCounter
totalPaths metric.Int64UpDownCounter
requestDuration metric.Int64Histogram
backendDuration metric.Int64Histogram
certificateIssueDuration metric.Int64Histogram
mappingsMux sync.Mutex
mappingPaths map[string]int
requestsTotal prometheus.Counter
activeRequests prometheus.Gauge
configuredDomains prometheus.Gauge
pathsPerDomain *prometheus.GaugeVec
requestDuration *prometheus.HistogramVec
backendDuration *prometheus.HistogramVec
}
func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
requestsTotal, err := meter.Int64Counter(
"proxy.http.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Total number of requests made to the netbird proxy"),
)
if err != nil {
return nil, err
}
activeRequests, err := meter.Int64UpDownCounter(
"proxy.http.active_requests",
metric.WithUnit("1"),
metric.WithDescription("Current in-flight requests handled by the netbird proxy"),
)
if err != nil {
return nil, err
}
configuredDomains, err := meter.Int64UpDownCounter(
"proxy.domains.count",
metric.WithUnit("1"),
metric.WithDescription("Current number of domains configured on the netbird proxy"),
)
if err != nil {
return nil, err
}
totalPaths, err := meter.Int64UpDownCounter(
"proxy.paths.count",
metric.WithUnit("1"),
metric.WithDescription("Total number of paths configured on the netbird proxy"),
)
if err != nil {
return nil, err
}
requestDuration, err := meter.Int64Histogram(
"proxy.http.request.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of requests made to the netbird proxy"),
)
if err != nil {
return nil, err
}
backendDuration, err := meter.Int64Histogram(
"proxy.backend.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of peer round trip time from the netbird proxy"),
)
if err != nil {
return nil, err
}
certificateIssueDuration, err := meter.Int64Histogram(
"proxy.certificate.issue.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of ACME certificate issuance"),
)
if err != nil {
return nil, err
}
func New(reg prometheus.Registerer) *Metrics {
promFactory := promauto.With(reg)
return &Metrics{
ctx: ctx,
requestsTotal: requestsTotal,
activeRequests: activeRequests,
configuredDomains: configuredDomains,
totalPaths: totalPaths,
requestDuration: requestDuration,
backendDuration: backendDuration,
certificateIssueDuration: certificateIssueDuration,
mappingPaths: make(map[string]int),
}, nil
requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{
Name: "netbird_proxy_requests_total",
Help: "Total number of requests made to the netbird proxy",
}),
activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{
Name: "netbird_proxy_active_requests_count",
Help: "Current in-flight requests handled by the netbird proxy",
}),
configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{
Name: "netbird_proxy_domains_count",
Help: "Current number of domains configured on the netbird proxy",
}),
pathsPerDomain: promFactory.NewGaugeVec(
prometheus.GaugeOpts{
Name: "netbird_proxy_paths_count",
Help: "Current number of paths configured on the netbird proxy labelled by domain",
},
[]string{"domain"},
),
requestDuration: promFactory.NewHistogramVec(
prometheus.HistogramOpts{
Name: "netbird_proxy_request_duration_seconds",
Help: "Duration of requests made to the netbird proxy",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
},
[]string{"status", "size", "method", "host", "path"},
),
backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{
Name: "netbird_proxy_backend_duration_seconds",
Help: "Duration of peer round trip time from the netbird proxy",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
},
[]string{"status", "size", "method", "host", "path"},
),
}
}
type responseInterceptor struct {
@@ -122,19 +80,23 @@ func (w *responseInterceptor) Write(b []byte) (int, error) {
// Middleware wraps next so that every request passing through it is counted,
// tracked as in-flight while it is being served, and timed once it completes.
//
// NOTE(review): this listing comes from a commit comparison and interleaves
// the OpenTelemetry and the Prometheus variant of the body. Both are kept as
// shown; the only change is the in-flight gauge decrement in the Prometheus
// variant.
func (m *Metrics) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		m.requestsTotal.Add(m.ctx, 1)
		m.activeRequests.Add(m.ctx, 1)
		m.requestsTotal.Inc()
		m.activeRequests.Inc()
		// The interceptor captures status and size for the duration labels.
		interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
		start := time.Now()
		defer func() {
			duration := time.Since(start)
			m.activeRequests.Add(m.ctx, -1)
			m.requestDuration.Record(m.ctx, duration.Milliseconds())
		}()
		next.ServeHTTP(interceptor, r)
		duration := time.Since(start)
		// Dec lowers the in-flight gauge again. The previous call, Desc, only
		// returns the metric's descriptor and never changes its value, so the
		// gauge grew with every request and never came back down.
		m.activeRequests.Dec()
		m.requestDuration.With(prometheus.Labels{
			"status": strconv.Itoa(interceptor.status),
			"size":   strconv.Itoa(interceptor.size),
			"method": r.Method,
			"host":   r.Host,
			"path":   r.URL.Path,
		}).Observe(duration.Seconds())
	})
}
@@ -146,52 +108,44 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
// RoundTripper wraps next so that the time spent in every backend round trip
// is observed in the backend duration metric. The request is passed to next
// unchanged, and next's response and error are returned unchanged.
//
// NOTE(review): this listing comes from a commit comparison; the
// m.backendDuration.Record call belongs to the OpenTelemetry variant and the
// label handling to the Prometheus variant.
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
// NOTE(review): "host" and "path" are taken verbatim from the request, so
// every distinct host/path pair creates its own time series — confirm this
// is acceptable for the expected traffic.
labels := prometheus.Labels{
"method": req.Method,
"host": req.Host,
// Fill potentially empty labels with default values to avoid cardinality issues.
"path": "/",
"status": "0",
"size": "0",
}
if req.URL != nil {
labels["path"] = req.URL.Path
}
start := time.Now()
res, err := next.RoundTrip(req)
duration := time.Since(start)
m.backendDuration.Record(m.ctx, duration.Milliseconds())
// Not all labels will be available if there was an error.
if res != nil {
labels["status"] = strconv.Itoa(res.StatusCode)
// net/http reports ContentLength as -1 when the length is unknown, in
// which case the "size" label becomes "-1".
labels["size"] = strconv.Itoa(int(res.ContentLength))
}
m.backendDuration.With(labels).Observe(duration.Seconds())
return res, err
})
}
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
m.mappingsMux.Lock()
defer m.mappingsMux.Unlock()
newPathCount := len(mapping.Paths)
oldPathCount, exists := m.mappingPaths[mapping.Host]
if !exists {
m.configuredDomains.Add(m.ctx, 1)
}
pathDelta := newPathCount - oldPathCount
if pathDelta != 0 {
m.totalPaths.Add(m.ctx, int64(pathDelta))
}
m.mappingPaths[mapping.Host] = newPathCount
m.configuredDomains.Inc()
m.pathsPerDomain.With(prometheus.Labels{
"domain": mapping.Host,
}).Set(float64(len(mapping.Paths)))
}
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
m.mappingsMux.Lock()
defer m.mappingsMux.Unlock()
oldPathCount, exists := m.mappingPaths[mapping.Host]
if !exists {
// Nothing to remove
return
}
m.configuredDomains.Add(m.ctx, -1)
m.totalPaths.Add(m.ctx, -int64(oldPathCount))
delete(m.mappingPaths, mapping.Host)
}
// RecordCertificateIssuance records the duration of a certificate issuance.
func (m *Metrics) RecordCertificateIssuance(duration time.Duration) {
m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds())
m.configuredDomains.Dec()
m.pathsPerDomain.With(prometheus.Labels{
"domain": mapping.Host,
}).Set(0)
}

View File

@@ -1,17 +1,13 @@
package metrics_test
import (
"context"
"net/http"
"net/url"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/sdk/metric"
"github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
)
type testRoundTripper struct {
@@ -51,19 +47,7 @@ func TestMetrics_RoundTripper(t *testing.T) {
},
}
exporter, err := prometheus.New()
if err != nil {
t.Fatalf("create prometheus exporter: %v", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath()
meter := provider.Meter(pkg)
m, err := metrics.New(context.Background(), meter)
if err != nil {
t.Fatalf("create metrics: %v", err)
}
m := metrics.New(prometheus.NewRegistry())
for name, test := range tests {
t.Run(name, func(t *testing.T) {

View File

@@ -28,12 +28,10 @@ func BenchmarkServeHTTP(b *testing.B) {
ID: rand.Text(),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: map[string]*proxy.PathTarget{
Paths: map[string]*url.URL{
"/": {
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:8080",
},
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
})
@@ -69,12 +67,10 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
ID: id,
AccountID: types.AccountID(rand.Text()),
Host: host,
Paths: map[string]*proxy.PathTarget{
Paths: map[string]*url.URL{
"/": {
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:8080",
},
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
})
@@ -104,17 +100,15 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
b.Fatal(err)
}
paths := make(map[string]*proxy.PathTarget, pathCount)
paths := make(map[string]*url.URL, pathCount)
for i := range pathCount {
path := "/" + rand.Text()
if int64(i) == targetIndex.Int64() {
target = path
}
paths[path] = &proxy.PathTarget{
URL: &url.URL{
Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
},
paths[path] = &url.URL{
Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
}
}
rp.AddMapping(proxy.Mapping{

View File

@@ -80,30 +80,14 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
capturedData.SetAccountId(result.accountID)
}
pt := result.target
if pt.SkipTLSVerify {
ctx = roundtrip.WithSkipTLSVerify(ctx)
}
if pt.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
defer cancel()
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
rewriteMatchedPath = ""
}
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
Transport: p.transport,
FlushInterval: -1,
ErrorHandler: proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
}
rp.ServeHTTP(w, r.WithContext(ctx))
}
@@ -113,22 +97,16 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// forwarding headers and stripping proxy authentication credentials.
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
// The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
switch pathRewrite {
case PathRewritePreserve:
// Keep the full original request path as-is.
default:
if matchedPath != "" && matchedPath != "/" {
// Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication.
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" {
r.Out.URL.Path = "/"
}
r.Out.URL.RawPath = ""
// Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication.
if matchedPath != "" && matchedPath != "/" {
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" {
r.Out.URL.Path = "/"
}
r.Out.URL.RawPath = ""
}
r.SetURL(target)
@@ -138,10 +116,6 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host
}
for k, v := range customHeaders {
r.Out.Header.Set(k, v)
}
clientIP := extractClientIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) {

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", true)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/app", false)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
rewrite := p.rewriteFunc(target, "/", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
@@ -546,82 +546,6 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
})
}
// TestRewriteFunc_PreservePath verifies that PathRewritePreserve forwards the
// incoming request path untouched, regardless of which prefix was matched.
func TestRewriteFunc_PreservePath(t *testing.T) {
	rp := &ReverseProxy{forwardedProto: "auto"}
	backend, _ := url.Parse("http://backend.internal:8080")

	t.Run("preserve keeps full request path", func(t *testing.T) {
		apply := rp.rewriteFunc(backend, "/api", false, PathRewritePreserve, nil)
		req := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
		apply(req)

		gotPath := req.Out.URL.Path
		assert.Equal(t, "/api/users/123", gotPath,
			"preserve should keep the full original request path")
	})

	t.Run("preserve with root matchedPath", func(t *testing.T) {
		apply := rp.rewriteFunc(backend, "/", false, PathRewritePreserve, nil)
		req := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
		apply(req)

		gotPath := req.Out.URL.Path
		assert.Equal(t, "/anything", gotPath)
	})
}
// TestRewriteFunc_CustomHeaders covers the customHeaders argument of
// rewriteFunc: headers are injected into the outbound request, a nil map is
// accepted, and configured values replace headers sent by the client.
func TestRewriteFunc_CustomHeaders(t *testing.T) {
	rp := &ReverseProxy{forwardedProto: "auto"}
	backend, _ := url.Parse("http://backend.internal:8080")

	t.Run("injects custom headers", func(t *testing.T) {
		apply := rp.rewriteFunc(backend, "/", false, PathRewriteDefault, map[string]string{
			"X-Custom-Auth": "token-abc",
			"X-Env":         "production",
		})
		req := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
		apply(req)

		out := req.Out.Header
		assert.Equal(t, "token-abc", out.Get("X-Custom-Auth"))
		assert.Equal(t, "production", out.Get("X-Env"))
	})

	t.Run("nil customHeaders is fine", func(t *testing.T) {
		apply := rp.rewriteFunc(backend, "/", false, PathRewriteDefault, nil)
		req := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
		apply(req)

		assert.Equal(t, "backend.internal:8080", req.Out.Host)
	})

	t.Run("custom headers override existing request headers", func(t *testing.T) {
		override := map[string]string{"X-Override": "new-value"}
		apply := rp.rewriteFunc(backend, "/", false, PathRewriteDefault, override)
		req := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
		// The client-supplied value must lose against the configured one.
		req.In.Header.Set("X-Override", "old-value")
		apply(req)

		assert.Equal(t, "new-value", req.Out.Header.Get("X-Override"))
	})
}
// TestRewriteFunc_PreservePathWithCustomHeaders checks that path preservation
// and custom header injection work together in a single rewrite.
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
	rp := &ReverseProxy{forwardedProto: "auto"}
	backend, _ := url.Parse("http://backend.internal:8080")
	extra := map[string]string{"X-Via": "proxy"}

	apply := rp.rewriteFunc(backend, "/api", false, PathRewritePreserve, extra)
	req := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
	apply(req)

	out := req.Out
	assert.Equal(t, "/api/deep/path", out.URL.Path, "preserve should keep the full original path")
	assert.Equal(t, "proxy", out.Header.Get("X-Via"), "custom header should be set")
}
func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }

View File

@@ -6,41 +6,21 @@ import (
"net/url"
"sort"
"strings"
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// PathRewriteMode controls how the request path is rewritten before forwarding.
// The zero value is PathRewriteDefault.
type PathRewriteMode int
const (
// PathRewriteDefault strips the matched prefix and joins with the target path.
// It is the first iota value (0), so an unset mode selects this behavior.
PathRewriteDefault PathRewriteMode = iota
// PathRewritePreserve keeps the full original request path as-is.
PathRewritePreserve
)
// PathTarget holds a backend URL and per-target behavioral options.
type PathTarget struct {
// URL is the backend URL requests for this path are forwarded to.
URL *url.URL
// SkipTLSVerify requests that TLS certificate verification be skipped for
// this backend. NOTE(review): presumably this selects the insecure
// transport in the roundtrip package; confirm against the caller.
SkipTLSVerify bool
// RequestTimeout is the per-target request timeout; it stays zero when the
// mapping options carry no timeout.
RequestTimeout time.Duration
// PathRewrite selects how the request path is rewritten before forwarding.
PathRewrite PathRewriteMode
// CustomHeaders are extra headers configured for this target.
// NOTE(review): presumably injected into the outbound request by the
// rewrite function; confirm against the caller.
CustomHeaders map[string]string
}
type Mapping struct {
ID string
AccountID types.AccountID
Host string
Paths map[string]*PathTarget
Paths map[string]*url.URL
PassHostHeader bool
RewriteRedirects bool
}
type targetResult struct {
target *PathTarget
url *url.URL
matchedPath string
serviceID string
accountID types.AccountID
@@ -75,14 +55,10 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) {
pt := m.Paths[path]
if pt == nil || pt.URL == nil {
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
continue
}
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
target := m.Paths[path]
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
return targetResult{
target: pt,
url: target,
matchedPath: path,
serviceID: m.ID,
accountID: m.AccountID,

View File

@@ -1,32 +0,0 @@
package roundtrip
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/types"
)
func TestAccountIDContext(t *testing.T) {
t.Run("returns empty when missing", func(t *testing.T) {
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
})
t.Run("round-trips value", func(t *testing.T) {
ctx := WithAccountID(context.Background(), "acc-123")
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
})
}
// TestSkipTLSVerifyContext checks that the skip-TLS-verify flag is off for a
// plain context and on once WithSkipTLSVerify has marked it.
func TestSkipTLSVerifyContext(t *testing.T) {
	t.Run("false by default", func(t *testing.T) {
		plain := context.Background()
		assert.False(t, skipTLSVerifyFromContext(plain))
	})

	t.Run("true when set", func(t *testing.T) {
		marked := WithSkipTLSVerify(context.Background())
		assert.True(t, skipTLSVerifyFromContext(marked))
	})
}

View File

@@ -2,7 +2,6 @@ package roundtrip
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
@@ -53,12 +52,9 @@ type domainNotification struct {
type clientEntry struct {
client *embed.Client
transport *http.Transport
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
@@ -134,9 +130,6 @@ type ClientDebugInfo struct {
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
@@ -256,33 +249,27 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
transport := &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
IdleConnTimeout: n.transportCfg.idleConnTimeout,
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression,
}
insecureTransport := transport.Clone()
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(),
started: false,
inflightMap: make(map[backendKey]chan struct{}),
maxInflight: n.transportCfg.maxInflight,
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
IdleConnTimeout: n.transportCfg.idleConnTimeout,
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression,
},
createdAt: time.Now(),
started: false,
inflightMap: make(map[backendKey]chan struct{}),
maxInflight: n.transportCfg.maxInflight,
}, nil
}
@@ -386,7 +373,6 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
client := entry.client
transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID)
n.clientsMux.Unlock()
@@ -401,7 +387,6 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
@@ -430,9 +415,6 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
}
client := entry.client
transport := entry.transport
if skipTLSVerifyFromContext(req.Context()) {
transport = entry.insecureTransport
}
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
@@ -475,7 +457,6 @@ func (n *NetBird) StopAll(ctx context.Context) error {
var merr *multierror.Error
for accountID, entry := range n.clients {
entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
@@ -598,14 +579,3 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
}
return accountID
}
// WithSkipTLSVerify returns a child of ctx flagged so that the backend
// connection uses the insecure transport, i.e. TLS certificate verification
// is skipped.
func WithSkipTLSVerify(ctx context.Context) context.Context {
	var key skipTLSVerifyContextKey
	return context.WithValue(ctx, key, true)
}
// skipTLSVerifyFromContext reports whether ctx carries the skip-TLS-verify
// flag. A missing or non-bool value yields false.
func skipTLSVerifyFromContext(ctx context.Context) bool {
	if skip, ok := ctx.Value(skipTLSVerifyContextKey{}).(bool); ok {
		return skip
	}
	return false
}

View File

@@ -116,9 +116,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
// Create real users manager
usersManager := users.NewManager(testStore)
@@ -134,7 +131,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,

View File

@@ -18,18 +18,16 @@ import (
"net/http"
"net/netip"
"net/url"
"os"
"path/filepath"
"reflect"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pires/go-proxyproto"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/sdk/metric"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@@ -45,7 +43,7 @@ import (
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -66,7 +64,7 @@ type Server struct {
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
meter *metrics.Metrics
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -155,19 +153,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.initDefaults()
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
}
reg := prometheus.NewRegistry()
s.meter = metrics.New(reg)
mgmtConn, err := s.dialManagement()
if err != nil {
@@ -194,8 +181,39 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
// TEMPORARY: Create a test transport that uses direct HTTP (bypasses NetBird tunnel)
testTransport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
WriteBufferSize: 256 * 1024,
ReadBufferSize: 256 * 1024,
}
// TEMPORARY: Start local file server for testing
go func() {
staticFile := os.Getenv("NB_PROXY_STATIC_FILE_PATH")
log.Infof("Reading static file from %s", staticFile)
fileServerMux := http.NewServeMux()
fileServerMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
s.Logger.Debugf("Serving test file to %s", r.RemoteAddr)
http.ServeFile(w, r, staticFile)
})
testServer := &http.Server{
Addr: "0.0.0.0:9999",
Handler: fileServerMux,
}
s.Logger.Info("Started test file server on http://0.0.0.0:9999/")
if err := testServer.ListenAndServe(); err != nil {
s.Logger.Warnf("Test file server error: %v", err)
}
}()
// Configure the reverse proxy using direct transport for testing (bypasses NetBird)
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(testTransport), s.ForwardedProto, s.TrustedProxies, s.Logger)
// TEMPORARY: Add static test mapping pointing to local file server
// Using "/" as the path to match all requests to this host
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
@@ -207,7 +225,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.startDebugEndpoint()
if err := s.startHealthServer(); err != nil {
if err := s.startHealthServer(reg); err != nil {
return err
}
@@ -242,6 +260,22 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
httpsErr <- s.https.ServeTLS(ln, "", "")
}()
hostDomain := os.Getenv("NB_PROXY_FILE_HOST")
testURL, _ := url.Parse("http://127.0.0.1:9999")
s.proxy.AddMapping(proxy.Mapping{
ID: "test-static-file",
AccountID: types.AccountID("test-account"),
Host: hostDomain,
Paths: map[string]*url.URL{
"/": testURL,
},
})
if s.acme != nil {
s.acme.AddDomain(domain.Domain(hostDomain), "test-account", "test-static-file")
}
s.Logger.Info("Added static test mapping: %s/* -> local test file server (bypassing NetBird tunnel)", hostDomain)
select {
case err := <-httpsErr:
s.shutdownServices()
@@ -298,12 +332,12 @@ func (s *Server) startDebugEndpoint() {
}
// startHealthServer launches the health probe and metrics server.
func (s *Server) startHealthServer() error {
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = defaultHealthAddr
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
healthListener, err := net.Listen("tcp", healthAddr)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
@@ -437,7 +471,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
"acme_server": s.ACMEDirectory,
"challenge_type": s.ACMEChallengeType,
}).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter)
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod)
if s.ACMEChallengeType == "http-01" {
s.http = &http.Server{
@@ -734,7 +768,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
}
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
paths := make(map[string]*proxy.PathTarget)
paths := make(map[string]*url.URL)
for _, pathMapping := range mapping.GetPath() {
targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil {
@@ -748,17 +782,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}).WithError(err).Error("failed to parse target URL for path, skipping")
continue
}
pt := &proxy.PathTarget{URL: targetURL}
if opts := pathMapping.GetOptions(); opts != nil {
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
pt.CustomHeaders = opts.GetCustomHeaders()
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
}
paths[pathMapping.GetPath()] = pt
paths[pathMapping.GetPath()] = targetURL
}
return proxy.Mapping{
ID: mapping.GetId(),
@@ -770,15 +794,6 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}
}
// protoToPathRewrite converts the wire-level path rewrite mode into the
// proxy package's PathRewriteMode. Only PATH_REWRITE_PRESERVE maps to
// PathRewritePreserve; every other value falls back to PathRewriteDefault.
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
	if mode == proto.PathRewriteMode_PATH_REWRITE_PRESERVE {
		return proxy.PathRewritePreserve
	}
	return proxy.PathRewriteDefault
}
// debugEndpointAddr returns the address for the debug endpoint.
// If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string {

View File

@@ -1,271 +0,0 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// testServiceTarget is the single enabled peer target (HTTPS, port 8443) used
// as a fixture by the reverse-proxy service tests in this file.
var testServiceTarget = api.ServiceTarget{
TargetId: "peer-123",
TargetType: "peer",
Protocol: "https",
Port: 8443,
Enabled: true,
}
// testService is the canonical service fixture the mock handlers return; it
// has a fixed creation time so its JSON encoding is deterministic, and it
// carries testServiceTarget as its only target.
var testService = api.Service{
Id: "svc-1",
Name: "test-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Meta: api.ServiceMeta{
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
Status: "active",
},
Targets: []api.ServiceTarget{testServiceTarget},
}
// TestReverseProxyServices_List_200 verifies that List decodes a successful
// response into the returned services.
func TestReverseProxyServices_List_200(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
			// Use assert, not require, inside the handler: FailNow must only be
			// called from the goroutine running the test, and the handler is
			// expected to run on the mock server's goroutine (withMockClient is
			// defined elsewhere; confirm).
			retBytes, err := json.Marshal([]api.Service{testService})
			assert.NoError(t, err)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.List(context.Background())
		require.NoError(t, err)
		require.Len(t, ret, 1)
		assert.Equal(t, testService.Id, ret[0].Id)
		assert.Equal(t, testService.Name, ret[0].Name)
	})
}
// TestReverseProxyServices_List_Err verifies that List surfaces the message
// of an error response and returns no services.
func TestReverseProxyServices_List_Err(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
			// Use assert, not require, inside the handler: FailNow must only be
			// called from the goroutine running the test, and the handler is
			// expected to run on the mock server's goroutine (withMockClient is
			// defined elsewhere; confirm).
			retBytes, err := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
			assert.NoError(t, err)
			w.WriteHeader(http.StatusBadRequest)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.List(context.Background())
		// require here so a nil err fails the test instead of panicking on
		// err.Error() below.
		require.Error(t, err)
		assert.Equal(t, "No", err.Error())
		assert.Empty(t, ret)
	})
}
// TestReverseProxyServices_Get_200 verifies that Get decodes a successful
// response into the returned service.
func TestReverseProxyServices_Get_200(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
			// Use assert, not require, inside the handler: FailNow must only be
			// called from the goroutine running the test, and the handler is
			// expected to run on the mock server's goroutine (withMockClient is
			// defined elsewhere; confirm).
			retBytes, err := json.Marshal(testService)
			assert.NoError(t, err)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
		require.NoError(t, err)
		// Guard the dereferences below against a nil result.
		require.NotNil(t, ret)
		assert.Equal(t, testService.Id, ret.Id)
		assert.Equal(t, testService.Domain, ret.Domain)
	})
}
// TestReverseProxyServices_Get_Err verifies that Get surfaces the message of
// an error response and returns a nil service.
func TestReverseProxyServices_Get_Err(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
			// Use assert, not require, inside the handler: FailNow must only be
			// called from the goroutine running the test, and the handler is
			// expected to run on the mock server's goroutine (withMockClient is
			// defined elsewhere; confirm).
			retBytes, err := json.Marshal(util.ErrorResponse{Message: "No", Code: 404})
			assert.NoError(t, err)
			w.WriteHeader(http.StatusNotFound)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
		// require here so a nil err fails the test instead of panicking on
		// err.Error() below.
		require.Error(t, err)
		assert.Equal(t, "No", err.Error())
		assert.Nil(t, ret)
	})
}
// TestReverseProxyServices_Create_200 verifies that Create sends a POST with
// the expected request body and decodes the created service from the response.
func TestReverseProxyServices_Create_200(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
			// Use assert plus explicit early returns, not require, inside the
			// handler: FailNow must only be called from the goroutine running
			// the test, and the handler is expected to run on the mock server's
			// goroutine (withMockClient is defined elsewhere; confirm).
			assert.Equal(t, http.MethodPost, r.Method)

			reqBytes, err := io.ReadAll(r.Body)
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			var req api.ServiceRequest
			if !assert.NoError(t, json.Unmarshal(reqBytes, &req)) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			assert.Equal(t, "test-service", req.Name)
			assert.Equal(t, "test.example.com", req.Domain)

			retBytes, err := json.Marshal(testService)
			assert.NoError(t, err)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
			Name:    "test-service",
			Domain:  "test.example.com",
			Enabled: true,
			Auth:    api.ServiceAuthConfig{},
			Targets: []api.ServiceTarget{testServiceTarget},
		})
		require.NoError(t, err)
		// Guard the dereference below against a nil result.
		require.NotNil(t, ret)
		assert.Equal(t, testService.Id, ret.Id)
	})
}
// TestReverseProxyServices_Create_Err verifies that Create surfaces the
// message of an error response and returns a nil service.
func TestReverseProxyServices_Create_Err(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
			// Use assert, not require, inside the handler: FailNow must only be
			// called from the goroutine running the test, and the handler is
			// expected to run on the mock server's goroutine (withMockClient is
			// defined elsewhere; confirm).
			retBytes, err := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
			assert.NoError(t, err)
			w.WriteHeader(http.StatusBadRequest)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
			Name:    "test-service",
			Domain:  "test.example.com",
			Enabled: true,
			Auth:    api.ServiceAuthConfig{},
			Targets: []api.ServiceTarget{testServiceTarget},
		})
		// require here so a nil err fails the test instead of panicking on
		// err.Error() below.
		require.Error(t, err)
		assert.Equal(t, "No", err.Error())
		assert.Nil(t, ret)
	})
}
// TestReverseProxyServices_Create_WithPerTargetOptions verifies that the
// per-target options (skip_tls_verify, request_timeout, path_rewrite and
// custom_headers) are serialized into the Create request body.
func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
			// Use assert plus explicit guards, not require, inside the handler:
			// FailNow must only be called from the goroutine running the test,
			// and the handler is expected to run on the mock server's goroutine
			// (withMockClient is defined elsewhere; confirm).
			assert.Equal(t, http.MethodPost, r.Method)

			reqBytes, err := io.ReadAll(r.Body)
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			var req api.ServiceRequest
			if !assert.NoError(t, json.Unmarshal(reqBytes, &req)) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}

			if !assert.Len(t, req.Targets, 1) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			opts := req.Targets[0].Options
			if !assert.NotNil(t, opts, "options should be present") {
				w.WriteHeader(http.StatusBadRequest)
				return
			}

			// Each option is checked independently, so one missing field does
			// not hide problems with the others; pointers are only
			// dereferenced after the presence check passed.
			if assert.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present") {
				assert.True(t, *opts.SkipTlsVerify)
			}
			if assert.NotNil(t, opts.RequestTimeout, "request_timeout should be present") {
				assert.Equal(t, "30s", *opts.RequestTimeout)
			}
			if assert.NotNil(t, opts.PathRewrite, "path_rewrite should be present") {
				assert.Equal(t, api.ServiceTargetOptionsPathRewrite("preserve"), *opts.PathRewrite)
			}
			if assert.NotNil(t, opts.CustomHeaders, "custom_headers should be present") {
				assert.Equal(t, "bar", (*opts.CustomHeaders)["X-Foo"])
			}

			retBytes, err := json.Marshal(testService)
			assert.NoError(t, err)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})

		pathRewrite := api.ServiceTargetOptionsPathRewrite("preserve")
		ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
			Name:    "test-service",
			Domain:  "test.example.com",
			Enabled: true,
			Auth:    api.ServiceAuthConfig{},
			Targets: []api.ServiceTarget{
				{
					TargetId:   "peer-123",
					TargetType: "peer",
					Protocol:   "https",
					Port:       8443,
					Enabled:    true,
					Options: &api.ServiceTargetOptions{
						SkipTlsVerify:  ptr(true),
						RequestTimeout: ptr("30s"),
						PathRewrite:    &pathRewrite,
						CustomHeaders:  &map[string]string{"X-Foo": "bar"},
					},
				},
			},
		})
		require.NoError(t, err)
		// Guard the dereference below against a nil result.
		require.NotNil(t, ret)
		assert.Equal(t, testService.Id, ret.Id)
	})
}
// TestReverseProxyServices_Update_200 verifies that Update sends a PUT to
// /api/reverse-proxies/services/svc-1 carrying the updated name, and that
// the service returned by the mock server is decoded for the caller.
func TestReverseProxyServices_Update_200(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
			// assert (not require) inside the handler: t.FailNow must only
			// be called from the goroutine running the test.
			assert.Equal(t, "PUT", r.Method)
			reqBytes, err := io.ReadAll(r.Body)
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			var req api.ServiceRequest
			if !assert.NoError(t, json.Unmarshal(reqBytes, &req)) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			assert.Equal(t, "updated-service", req.Name)
			retBytes, err := json.Marshal(testService)
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})
		ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
			Name:    "updated-service",
			Domain:  "test.example.com",
			Enabled: true,
			Auth:    api.ServiceAuthConfig{},
			Targets: []api.ServiceTarget{testServiceTarget},
		})
		require.NoError(t, err)
		// Guard the dereference below so a nil result fails the test
		// instead of panicking.
		require.NotNil(t, ret)
		assert.Equal(t, testService.Id, ret.Id)
	})
}
// TestReverseProxyServices_Update_Err verifies that a 400 response with an
// error payload is surfaced by Update as an error whose message is the
// payload's Message, and that no service is returned.
func TestReverseProxyServices_Update_Err(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
			// assert (not require) inside the handler: t.FailNow must only
			// be called from the goroutine running the test.
			retBytes, err := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			w.WriteHeader(400)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})
		ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
			Name:    "updated-service",
			Domain:  "test.example.com",
			Enabled: true,
			Auth:    api.ServiceAuthConfig{},
			Targets: []api.ServiceTarget{testServiceTarget},
		})
		// require stops the test here when err is nil; with assert the
		// following err.Error() call would panic on a nil error.
		require.Error(t, err)
		assert.Equal(t, "No", err.Error())
		assert.Nil(t, ret)
	})
}
// TestReverseProxyServices_Delete_200 checks that Delete issues a DELETE
// request for service "svc-1" and reports no error when the server
// answers with status 200.
func TestReverseProxyServices_Delete_200(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		// Mock endpoint: confirm the HTTP verb and acknowledge the deletion.
		onDelete := func(w http.ResponseWriter, r *http.Request) {
			assert.Equal(t, "DELETE", r.Method)
			w.WriteHeader(http.StatusOK)
		}
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", onDelete)

		ctx := context.Background()
		require.NoError(t, c.ReverseProxyServices.Delete(ctx, "svc-1"))
	})
}
// TestReverseProxyServices_Delete_Err verifies that a 404 response with an
// error payload is surfaced by Delete as an error whose message is the
// payload's Message.
func TestReverseProxyServices_Delete_Err(t *testing.T) {
	withMockClient(func(c *rest.Client, mux *http.ServeMux) {
		mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
			// assert (not require) inside the handler: t.FailNow must only
			// be called from the goroutine running the test.
			retBytes, err := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
			if !assert.NoError(t, err) {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			w.WriteHeader(404)
			_, err = w.Write(retBytes)
			assert.NoError(t, err)
		})
		err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
		// require stops the test here when err is nil; with assert the
		// following err.Error() call would panic on a nil error.
		require.Error(t, err)
		assert.Equal(t, "Not found", err.Error())
	})
}

View File

@@ -2822,16 +2822,6 @@ components:
type: string
description: "City name from geolocation"
example: "San Francisco"
bytes_upload:
type: integer
format: int64
description: "Bytes uploaded (request body size)"
example: 1024
bytes_download:
type: integer
format: int64
description: "Bytes downloaded (response body size)"
example: 8192
required:
- id
- service_id
@@ -2841,8 +2831,6 @@ components:
- path
- duration_ms
- status_code
- bytes_upload
- bytes_download
ProxyAccessLogsResponse:
type: object
properties:
@@ -3039,28 +3027,6 @@ components:
- targets
- auth
- enabled
# ServiceTargetOptions: optional per-target settings, referenced from
# ServiceTarget.options. Every property is optional (no `required` list).
ServiceTargetOptions:
type: object
properties:
skip_tls_verify:
type: boolean
description: Skip TLS certificate verification for this backend
request_timeout:
type: string
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
path_rewrite:
type: string
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
enum: [preserve]
custom_headers:
type: object
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
# Header names are limited to the character set in the pattern below.
propertyNames:
type: string
pattern: '^[!#$%&''*+.^_`|~0-9A-Za-z-]+$'
# Header values must not contain CR or LF characters.
additionalProperties:
type: string
pattern: '^[^\r\n]*$'
ServiceTarget:
type: object
properties:
@@ -3087,8 +3053,6 @@ components:
enabled:
type: boolean
description: Whether this target is enabled
options:
$ref: '#/components/schemas/ServiceTargetOptions'
required:
- target_id
- target_type

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v6.33.0
// protoc v6.33.3
// source: management.proto
package proto

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ package management;
option go_package = "/proto";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
// ProxyService - Management is the SERVER, Proxy is the CLIENT
@@ -51,22 +50,9 @@ enum ProxyMappingUpdateType {
UPDATE_TYPE_REMOVED = 2;
}
// PathRewriteMode enumerates the path rewrite modes usable in
// PathTargetOptions.path_rewrite.
// NOTE(review): presumably mirrors the REST path_rewrite option (default
// strips the matched prefix, preserve keeps the full original request
// path) — confirm against the proxy implementation.
enum PathRewriteMode {
PATH_REWRITE_DEFAULT = 0;
PATH_REWRITE_PRESERVE = 1;
}
// PathTargetOptions carries the optional per-target settings attached to a
// PathMapping.
// NOTE(review): field meanings are presumed to match the REST
// ServiceTargetOptions schema — confirm.
message PathTargetOptions {
// Presumably disables TLS certificate verification for the backend.
bool skip_tls_verify = 1;
// Timeout expressed as a protobuf Duration.
google.protobuf.Duration request_timeout = 2;
// Path rewrite mode; see PathRewriteMode.
PathRewriteMode path_rewrite = 3;
// Header name -> value pairs; presumably added to backend requests.
map<string, string> custom_headers = 4;
}
// PathMapping associates a path with a target, plus optional per-target
// options.
message PathMapping {
string path = 1;
string target = 2;
// Optional settings for this mapping's target; see PathTargetOptions.
PathTargetOptions options = 3;
}
message Authentication {
@@ -115,8 +101,6 @@ message AccessLog {
string auth_mechanism = 11;
string user_id = 12;
bool auth_success = 13;
int64 bytes_upload = 14;
int64 bytes_download = 15;
}
message AuthenticateRequest {