mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-20 07:39:56 +00:00
allow redis for token store
This commit is contained in:
@@ -188,7 +188,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
|
||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy token store: %v", err)
|
||||
}
|
||||
log.Info("One-time token store initialized for proxy authentication")
|
||||
return tokenStore
|
||||
})
|
||||
|
||||
@@ -1,28 +1,21 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||
// a service is created and must be used exactly once by the proxy
|
||||
// to authenticate a subsequent RPC call.
|
||||
type OneTimeTokenStore struct {
|
||||
tokens map[string]*tokenMetadata
|
||||
mu sync.RWMutex
|
||||
cleanup *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// tokenMetadata stores information about a one-time token
|
||||
type tokenMetadata struct {
|
||||
ServiceID string
|
||||
AccountID string
|
||||
@@ -30,20 +23,24 @@ type tokenMetadata struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||
// of expired tokens. The cleanupInterval determines how often expired
|
||||
// tokens are removed from memory.
|
||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
store := &OneTimeTokenStore{
|
||||
tokens: make(map[string]*tokenMetadata),
|
||||
cleanup: time.NewTicker(cleanupInterval),
|
||||
cleanupDone: make(chan struct{}),
|
||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type OneTimeTokenStore struct {
|
||||
store store.StoreInterface
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
return &OneTimeTokenStore{
|
||||
store: cacheStore,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates a new cryptographically secure one-time token
|
||||
@@ -52,25 +49,25 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
//
|
||||
// Returns the generated token string or an error if random generation fails.
|
||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tokens[token] = &tokenMetadata{
|
||||
metadata := &tokenMetadata{
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.store.Set(s.ctx, hashedToken, metadata, store.WithExpiration(ttl)); err != nil {
|
||||
return "", fmt.Errorf("failed to store token: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||
serviceID, accountID, ttl)
|
||||
|
||||
@@ -88,80 +85,46 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
||||
// - Account ID doesn't match
|
||||
// - Reverse proxy ID doesn't match
|
||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
metadata, exists := s.tokens[token]
|
||||
if !exists {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
value, err := s.store.Get(s.ctx, hashedToken)
|
||||
if err != nil {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
metadata, ok := value.(*tokenMetadata)
|
||||
if !ok {
|
||||
log.Warnf("Token validation failed: invalid metadata type (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("invalid token metadata")
|
||||
}
|
||||
|
||||
if time.Now().After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
s.store.Delete(s.ctx, hashedToken)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||
metadata.AccountID, accountID)
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||
return fmt.Errorf("account ID mismatch")
|
||||
}
|
||||
|
||||
// Validate service ID using constant-time comparison
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||
metadata.ServiceID, serviceID)
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||
return fmt.Errorf("service ID mismatch")
|
||||
}
|
||||
|
||||
// Delete token immediately to enforce single-use
|
||||
delete(s.tokens, token)
|
||||
if err := s.store.Delete(s.ctx, hashedToken); err != nil {
|
||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
}
|
||||
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||
serviceID, accountID)
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanup.C:
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for token, metadata := range s.tokens {
|
||||
if now.After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
case <-s.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine and releases resources
|
||||
func (s *OneTimeTokenStore) Close() {
|
||||
s.cleanup.Stop()
|
||||
close(s.cleanupDone)
|
||||
}
|
||||
|
||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.tokens)
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -486,33 +486,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
@@ -41,8 +42,8 @@ func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
@@ -96,8 +97,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
@@ -131,8 +132,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
|
||||
@@ -37,7 +37,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager)
|
||||
proxyService.SetProxyManager(proxyManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
Reference in New Issue
Block a user