[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -99,6 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -108,6 +110,8 @@ type Account struct {
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
ReverseProxyFreeDomainNonce string
}
func (a *Account) InitOnce() {
@@ -902,6 +906,11 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy())
}
services := []*reverseproxy.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
@@ -923,6 +932,7 @@ func (a *Account) Copy() *Account {
Networks: nets,
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Services: services,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
@@ -1213,7 +1223,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peer, ok := a.Peers[p]
if !ok || peer == nil {
if !ok || peer == nil || peer.ProxyMeta.Embedded {
continue
}
@@ -1776,6 +1786,110 @@ func (a *Account) GetActiveGroupUsers() map[string][]string {
return groups
}
func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
proxyPeers := make(map[string][]*nbpeer.Peer)
for _, peer := range a.Peers {
if peer.ProxyMeta.Embedded {
proxyPeers[peer.ProxyMeta.Cluster] = append(proxyPeers[peer.ProxyMeta.Cluster], peer)
}
}
return proxyPeers
}
func (a *Account) InjectProxyPolicies(ctx context.Context) {
if len(a.Services) == 0 {
return
}
proxyPeersByCluster := a.GetProxyPeers()
if len(proxyPeersByCluster) == 0 {
return
}
for _, service := range a.Services {
if !service.Enabled {
continue
}
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
}
}
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets {
if !target.Enabled {
continue
}
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target)
if !ok {
return
}
path := ""
if target.Path != nil {
path = *target.Path
}
for _, proxyPeer := range proxyPeers {
policy := a.createProxyPolicy(service, target, proxyPeer, port, path)
a.Policies = append(a.Policies, policy)
}
}
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
if target.Port != 0 {
return target.Port, true
}
switch target.Protocol {
case "https":
return 443, true
case "http":
return 80, true
default:
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
return 0, false
}
}
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{
ID: policyID,
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
Enabled: true,
Rules: []*PolicyRule{
{
ID: policyID,
PolicyID: policyID,
Name: fmt.Sprintf("Allow access to %s", service.Name),
Enabled: true,
SourceResource: Resource{
ID: proxyPeer.ID,
Type: ResourceTypePeer,
},
DestinationResource: Resource{
ID: target.TargetId,
Type: ResourceType(target.TargetType),
},
Bidirectional: false,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
PortRanges: []RulePortRange{
{
Start: uint16(port),
End: uint16(port),
},
},
},
},
}
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -70,7 +71,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -115,7 +116,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -177,7 +178,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -240,7 +241,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -317,7 +318,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -402,7 +403,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -458,7 +459,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -537,7 +538,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -597,7 +598,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})

View File

@@ -0,0 +1,7 @@
package types
// ProxyCallbackEndpoint holds the proxy callback endpoint
const ProxyCallbackEndpoint = "/reverse-proxy/callback"
// ProxyCallbackEndpointFull holds the proxy callback endpoint with api suffix
const ProxyCallbackEndpointFull = "/api" + ProxyCallbackEndpoint

View File

@@ -0,0 +1,137 @@
package types
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/management/server/util"
)
const (
// ProxyTokenPrefix is the globally used prefix for proxy access tokens
ProxyTokenPrefix = "nbx_"
// ProxyTokenSecretLength is the number of characters used for the secret
ProxyTokenSecretLength = 30
// ProxyTokenChecksumLength is the number of characters used for the encoded checksum
ProxyTokenChecksumLength = 6
// ProxyTokenLength is the total number of characters used for the token
ProxyTokenLength = 40
)
// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded.
type HashedProxyToken string
// PlainProxyToken is the raw token string displayed once at creation time.
type PlainProxyToken string
// ProxyAccessToken holds information about a proxy access token including a hashed version for verification
type ProxyAccessToken struct {
ID string `gorm:"primaryKey"`
Name string
HashedToken HashedProxyToken `gorm:"type:varchar(255);uniqueIndex"`
// AccountID is nil for management-wide tokens, set for account-scoped tokens
AccountID *string `gorm:"index"`
ExpiresAt *time.Time
CreatedBy string
CreatedAt time.Time
LastUsed *time.Time
Revoked bool
}
// IsExpired returns true if the token has expired
func (t *ProxyAccessToken) IsExpired() bool {
if t.ExpiresAt == nil {
return false
}
return time.Now().After(*t.ExpiresAt)
}
// IsValid returns true if the token is not revoked and not expired
func (t *ProxyAccessToken) IsValid() bool {
return !t.Revoked && !t.IsExpired()
}
// ProxyAccessTokenGenerated holds the new token and the plain text version
type ProxyAccessTokenGenerated struct {
PlainToken PlainProxyToken
ProxyAccessToken
}
// CreateNewProxyAccessToken generates a new proxy access token.
// Returns the token with hashed value stored and plain token for one-time display.
func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateProxyToken()
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
var expiresAt *time.Time
if expiresIn > 0 {
expiresAt = util.ToPtr(currentTime.Add(expiresIn))
}
return &ProxyAccessTokenGenerated{
ProxyAccessToken: ProxyAccessToken{
ID: xid.New().String(),
Name: name,
HashedToken: hashedToken,
AccountID: accountID,
ExpiresAt: expiresAt,
CreatedBy: createdBy,
CreatedAt: currentTime,
Revoked: false,
},
PlainToken: plainToken,
}, nil
}
func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) {
secret, err := b.Random(ProxyTokenSecretLength)
if err != nil {
return "", "", err
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum)
return plainToken.Hash(), plainToken, nil
}
// Hash returns the SHA-256 hash of the plain token, base64-encoded.
func (t PlainProxyToken) Hash() HashedProxyToken {
h := sha256.Sum256([]byte(t))
return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:]))
}
// Validate checks the format of a proxy token without checking the database.
func (t PlainProxyToken) Validate() error {
if !strings.HasPrefix(string(t), ProxyTokenPrefix) {
return fmt.Errorf("invalid token prefix")
}
if len(t) != ProxyTokenLength {
return fmt.Errorf("invalid token length")
}
secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength]
checksumStr := t[len(t)-ProxyTokenChecksumLength:]
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum))
if string(checksumStr) != expectedChecksumStr {
return fmt.Errorf("invalid token checksum")
}
return nil
}

View File

@@ -0,0 +1,155 @@
package types
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPlainProxyToken_Validate(t *testing.T) {
tests := []struct {
name string
token PlainProxyToken
wantErr bool
errMsg string
}{
{
name: "valid token",
token: "", // will be generated
wantErr: false,
},
{
name: "wrong prefix",
token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "too short",
token: "nbx_short",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "too long",
token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "correct length but invalid checksum",
token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum
wantErr: true,
errMsg: "invalid token checksum",
},
{
name: "empty token",
token: "",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "only prefix",
token: "nbx_",
wantErr: true,
errMsg: "invalid token length",
},
}
// Generate a valid token for the first test
generated, err := CreateNewProxyAccessToken("test", 0, nil, "test")
require.NoError(t, err)
tests[0].token = generated.PlainToken
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.token.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestPlainProxyToken_Hash(t *testing.T) {
token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X")
hash1 := token1.Hash()
hash2 := token2.Hash()
hash3 := token3.Hash()
assert.Equal(t, hash1, hash2, "same token should produce same hash")
assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes")
assert.NotEmpty(t, hash1)
}
func TestCreateNewProxyAccessToken(t *testing.T) {
t.Run("creates valid token", func(t *testing.T) {
generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user")
require.NoError(t, err)
assert.NotEmpty(t, generated.ID)
assert.Equal(t, "test-token", generated.Name)
assert.Equal(t, "test-user", generated.CreatedBy)
assert.NotEmpty(t, generated.HashedToken)
assert.NotEmpty(t, generated.PlainToken)
assert.Nil(t, generated.ExpiresAt)
assert.False(t, generated.Revoked)
assert.NoError(t, generated.PlainToken.Validate())
assert.Equal(t, ProxyTokenLength, len(generated.PlainToken))
assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)]))
})
t.Run("tokens are unique", func(t *testing.T) {
gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user")
require.NoError(t, err)
gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user")
require.NoError(t, err)
assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken)
assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken)
assert.NotEqual(t, gen1.ID, gen2.ID)
})
}
func TestProxyAccessToken_IsExpired(t *testing.T) {
past := time.Now().Add(-1 * time.Hour)
future := time.Now().Add(1 * time.Hour)
t.Run("expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &past}
assert.True(t, token.IsExpired())
})
t.Run("not expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &future}
assert.False(t, token.IsExpired())
})
t.Run("no expiration", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: nil}
assert.False(t, token.IsExpired())
})
}
func TestProxyAccessToken_IsValid(t *testing.T) {
token := &ProxyAccessToken{
Revoked: false,
}
assert.True(t, token.IsValid())
token.Revoked = true
assert.False(t, token.IsValid())
}