mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy --------- Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk> Co-authored-by: mlsmaycon <mlsmaycon@gmail.com> Co-authored-by: Eduard Gert <kontakt@eduardgert.de> Co-authored-by: Viktor Liu <viktor@netbird.io> Co-authored-by: Diego Noguês <diego.sure@gmail.com> Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -99,6 +100,7 @@ type Account struct {
|
||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -108,6 +110,8 @@ type Account struct {
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
|
||||
ReverseProxyFreeDomainNonce string
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
@@ -902,6 +906,11 @@ func (a *Account) Copy() *Account {
|
||||
networkResources = append(networkResources, resource.Copy())
|
||||
}
|
||||
|
||||
services := []*reverseproxy.Service{}
|
||||
for _, service := range a.Services {
|
||||
services = append(services, service.Copy())
|
||||
}
|
||||
|
||||
return &Account{
|
||||
Id: a.Id,
|
||||
CreatedBy: a.CreatedBy,
|
||||
@@ -923,6 +932,7 @@ func (a *Account) Copy() *Account {
|
||||
Networks: nets,
|
||||
NetworkRouters: networkRouters,
|
||||
NetworkResources: networkResources,
|
||||
Services: services,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
nmapInitOnce: a.nmapInitOnce,
|
||||
@@ -1213,7 +1223,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
if !ok || peer == nil || peer.ProxyMeta.Embedded {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1776,6 +1786,110 @@ func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||
return groups
|
||||
}
|
||||
|
||||
func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
|
||||
proxyPeers := make(map[string][]*nbpeer.Peer)
|
||||
for _, peer := range a.Peers {
|
||||
if peer.ProxyMeta.Embedded {
|
||||
proxyPeers[peer.ProxyMeta.Cluster] = append(proxyPeers[peer.ProxyMeta.Cluster], peer)
|
||||
}
|
||||
}
|
||||
return proxyPeers
|
||||
}
|
||||
|
||||
func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
if len(a.Services) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
proxyPeersByCluster := a.GetProxyPeers()
|
||||
if len(proxyPeersByCluster) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, service := range a.Services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectServiceProxyPolicies(ctx, service, proxyPeersByCluster)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeersByCluster[service.ProxyCluster])
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
path := ""
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
policy := a.createProxyPolicy(service, target, proxyPeer, port, path)
|
||||
a.Policies = append(a.Policies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
|
||||
switch target.Protocol {
|
||||
case "https":
|
||||
return 443, true
|
||||
case "http":
|
||||
return 80, true
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access to %s", service.Name),
|
||||
Enabled: true,
|
||||
SourceResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
DestinationResource: Resource{
|
||||
ID: target.TargetId,
|
||||
Type: ResourceType(target.TargetType),
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{
|
||||
Start: uint16(port),
|
||||
End: uint16(port),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -70,7 +71,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -115,7 +116,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -177,7 +178,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -240,7 +241,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -317,7 +318,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -402,7 +403,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -458,7 +459,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -537,7 +538,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
@@ -597,7 +598,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
7
management/server/types/proxy.go
Normal file
7
management/server/types/proxy.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package types
|
||||
|
||||
// ProxyCallbackEndpoint holds the proxy callback endpoint
|
||||
const ProxyCallbackEndpoint = "/reverse-proxy/callback"
|
||||
|
||||
// ProxyCallbackEndpointFull holds the proxy callback endpoint with api suffix
|
||||
const ProxyCallbackEndpointFull = "/api" + ProxyCallbackEndpoint
|
||||
137
management/server/types/proxy_access_token.go
Normal file
137
management/server/types/proxy_access_token.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// ProxyTokenPrefix is the globally used prefix for proxy access tokens
|
||||
ProxyTokenPrefix = "nbx_"
|
||||
// ProxyTokenSecretLength is the number of characters used for the secret
|
||||
ProxyTokenSecretLength = 30
|
||||
// ProxyTokenChecksumLength is the number of characters used for the encoded checksum
|
||||
ProxyTokenChecksumLength = 6
|
||||
// ProxyTokenLength is the total number of characters used for the token
|
||||
ProxyTokenLength = 40
|
||||
)
|
||||
|
||||
// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded.
|
||||
type HashedProxyToken string
|
||||
|
||||
// PlainProxyToken is the raw token string displayed once at creation time.
|
||||
type PlainProxyToken string
|
||||
|
||||
// ProxyAccessToken holds information about a proxy access token including a hashed version for verification
|
||||
type ProxyAccessToken struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Name string
|
||||
HashedToken HashedProxyToken `gorm:"type:varchar(255);uniqueIndex"`
|
||||
// AccountID is nil for management-wide tokens, set for account-scoped tokens
|
||||
AccountID *string `gorm:"index"`
|
||||
ExpiresAt *time.Time
|
||||
CreatedBy string
|
||||
CreatedAt time.Time
|
||||
LastUsed *time.Time
|
||||
Revoked bool
|
||||
}
|
||||
|
||||
// IsExpired returns true if the token has expired
|
||||
func (t *ProxyAccessToken) IsExpired() bool {
|
||||
if t.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*t.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsValid returns true if the token is not revoked and not expired
|
||||
func (t *ProxyAccessToken) IsValid() bool {
|
||||
return !t.Revoked && !t.IsExpired()
|
||||
}
|
||||
|
||||
// ProxyAccessTokenGenerated holds the new token and the plain text version
|
||||
type ProxyAccessTokenGenerated struct {
|
||||
PlainToken PlainProxyToken
|
||||
ProxyAccessToken
|
||||
}
|
||||
|
||||
// CreateNewProxyAccessToken generates a new proxy access token.
|
||||
// Returns the token with hashed value stored and plain token for one-time display.
|
||||
func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateProxyToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentTime := time.Now().UTC()
|
||||
var expiresAt *time.Time
|
||||
if expiresIn > 0 {
|
||||
expiresAt = util.ToPtr(currentTime.Add(expiresIn))
|
||||
}
|
||||
|
||||
return &ProxyAccessTokenGenerated{
|
||||
ProxyAccessToken: ProxyAccessToken{
|
||||
ID: xid.New().String(),
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: currentTime,
|
||||
Revoked: false,
|
||||
},
|
||||
PlainToken: plainToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) {
|
||||
secret, err := b.Random(ProxyTokenSecretLength)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
checksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
encodedChecksum := base62.Encode(checksum)
|
||||
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
|
||||
plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum)
|
||||
return plainToken.Hash(), plainToken, nil
|
||||
}
|
||||
|
||||
// Hash returns the SHA-256 hash of the plain token, base64-encoded.
|
||||
func (t PlainProxyToken) Hash() HashedProxyToken {
|
||||
h := sha256.Sum256([]byte(t))
|
||||
return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:]))
|
||||
}
|
||||
|
||||
// Validate checks the format of a proxy token without checking the database.
|
||||
func (t PlainProxyToken) Validate() error {
|
||||
if !strings.HasPrefix(string(t), ProxyTokenPrefix) {
|
||||
return fmt.Errorf("invalid token prefix")
|
||||
}
|
||||
|
||||
if len(t) != ProxyTokenLength {
|
||||
return fmt.Errorf("invalid token length")
|
||||
}
|
||||
|
||||
secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength]
|
||||
checksumStr := t[len(t)-ProxyTokenChecksumLength:]
|
||||
|
||||
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum))
|
||||
|
||||
if string(checksumStr) != expectedChecksumStr {
|
||||
return fmt.Errorf("invalid token checksum")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
155
management/server/types/proxy_access_token_test.go
Normal file
155
management/server/types/proxy_access_token_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPlainProxyToken_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token PlainProxyToken
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
token: "", // will be generated
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong prefix",
|
||||
token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token prefix",
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
token: "nbx_short",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
{
|
||||
name: "correct length but invalid checksum",
|
||||
token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum
|
||||
wantErr: true,
|
||||
errMsg: "invalid token checksum",
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token prefix",
|
||||
},
|
||||
{
|
||||
name: "only prefix",
|
||||
token: "nbx_",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate a valid token for the first test
|
||||
generated, err := CreateNewProxyAccessToken("test", 0, nil, "test")
|
||||
require.NoError(t, err)
|
||||
tests[0].token = generated.PlainToken
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.token.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainProxyToken_Hash(t *testing.T) {
|
||||
token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
|
||||
token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
|
||||
token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X")
|
||||
|
||||
hash1 := token1.Hash()
|
||||
hash2 := token2.Hash()
|
||||
hash3 := token3.Hash()
|
||||
|
||||
assert.Equal(t, hash1, hash2, "same token should produce same hash")
|
||||
assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes")
|
||||
assert.NotEmpty(t, hash1)
|
||||
}
|
||||
|
||||
func TestCreateNewProxyAccessToken(t *testing.T) {
|
||||
t.Run("creates valid token", func(t *testing.T) {
|
||||
generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, generated.ID)
|
||||
assert.Equal(t, "test-token", generated.Name)
|
||||
assert.Equal(t, "test-user", generated.CreatedBy)
|
||||
assert.NotEmpty(t, generated.HashedToken)
|
||||
assert.NotEmpty(t, generated.PlainToken)
|
||||
assert.Nil(t, generated.ExpiresAt)
|
||||
assert.False(t, generated.Revoked)
|
||||
|
||||
assert.NoError(t, generated.PlainToken.Validate())
|
||||
assert.Equal(t, ProxyTokenLength, len(generated.PlainToken))
|
||||
assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)]))
|
||||
})
|
||||
|
||||
t.Run("tokens are unique", func(t *testing.T) {
|
||||
gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken)
|
||||
assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken)
|
||||
assert.NotEqual(t, gen1.ID, gen2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyAccessToken_IsExpired(t *testing.T) {
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: &past}
|
||||
assert.True(t, token.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("not expired token", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: &future}
|
||||
assert.False(t, token.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("no expiration", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: nil}
|
||||
assert.False(t, token.IsExpired())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyAccessToken_IsValid(t *testing.T) {
|
||||
token := &ProxyAccessToken{
|
||||
Revoked: false,
|
||||
}
|
||||
|
||||
assert.True(t, token.IsValid())
|
||||
|
||||
token.Revoked = true
|
||||
assert.False(t, token.IsValid())
|
||||
}
|
||||
Reference in New Issue
Block a user