mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
[management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
This commit is contained in:
@@ -1,28 +1,23 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||
// a service is created and must be used exactly once by the proxy
|
||||
// to authenticate a subsequent RPC call.
|
||||
type OneTimeTokenStore struct {
|
||||
tokens map[string]*tokenMetadata
|
||||
mu sync.RWMutex
|
||||
cleanup *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// tokenMetadata stores information about a one-time token
|
||||
type tokenMetadata struct {
|
||||
ServiceID string
|
||||
AccountID string
|
||||
@@ -30,20 +25,24 @@ type tokenMetadata struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||
// of expired tokens. The cleanupInterval determines how often expired
|
||||
// tokens are removed from memory.
|
||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
store := &OneTimeTokenStore{
|
||||
tokens: make(map[string]*tokenMetadata),
|
||||
cleanup: time.NewTicker(cleanupInterval),
|
||||
cleanupDone: make(chan struct{}),
|
||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type OneTimeTokenStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
return &OneTimeTokenStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates a new cryptographically secure one-time token
|
||||
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
//
|
||||
// Returns the generated token string or an error if random generation fails.
|
||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tokens[token] = &tokenMetadata{
|
||||
metadata := &tokenMetadata{
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||
return "", fmt.Errorf("failed to store token: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||
serviceID, accountID, ttl)
|
||||
|
||||
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
||||
// - Account ID doesn't match
|
||||
// - Reverse proxy ID doesn't match
|
||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
metadata, exists := s.tokens[token]
|
||||
if !exists {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||
if err != nil {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
metadata := &tokenMetadata{}
|
||||
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
return fmt.Errorf("invalid token metadata")
|
||||
}
|
||||
|
||||
if time.Now().After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||
metadata.AccountID, accountID)
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||
return fmt.Errorf("account ID mismatch")
|
||||
}
|
||||
|
||||
// Validate service ID using constant-time comparison
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||
metadata.ServiceID, serviceID)
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||
return fmt.Errorf("service ID mismatch")
|
||||
}
|
||||
|
||||
// Delete token immediately to enforce single-use
|
||||
delete(s.tokens, token)
|
||||
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
}
|
||||
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||
serviceID, accountID)
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanup.C:
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for token, metadata := range s.tokens {
|
||||
if now.After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
case <-s.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine and releases resources
|
||||
func (s *OneTimeTokenStore) Close() {
|
||||
s.cleanup.Stop()
|
||||
close(s.cleanupDone)
|
||||
}
|
||||
|
||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.tokens)
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user