Add proxy <-> management authentication

This commit is contained in:
Viktor Liu
2026-02-06 02:06:24 +08:00
parent 2f263bf7e6
commit 0a3a9f977d
17 changed files with 1256 additions and 17 deletions

View File

@@ -80,4 +80,10 @@ func init() {
migrationCmd.AddCommand(upCmd)
rootCmd.AddCommand(migrationCmd)
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
tokenCmd.AddCommand(tokenCreateCmd)
tokenCmd.AddCommand(tokenListCmd)
tokenCmd.AddCommand(tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}

208
management/cmd/token.go Normal file
View File

@@ -0,0 +1,208 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!")
fmt.Printf("Token: %s\n", generated.PlainToken)
fmt.Println()
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.")
fmt.Printf("Token ID: %s\n", generated.ID)
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID)
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -0,0 +1,101 @@
package cmd
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseDuration(t *testing.T) {
tests := []struct {
name string
input string
expected time.Duration
wantErr bool
}{
{
name: "empty string returns zero",
input: "",
expected: 0,
},
{
name: "days suffix",
input: "30d",
expected: 30 * 24 * time.Hour,
},
{
name: "one day",
input: "1d",
expected: 24 * time.Hour,
},
{
name: "365 days",
input: "365d",
expected: 365 * 24 * time.Hour,
},
{
name: "hours via Go duration",
input: "24h",
expected: 24 * time.Hour,
},
{
name: "minutes via Go duration",
input: "30m",
expected: 30 * time.Minute,
},
{
name: "complex Go duration",
input: "1h30m",
expected: 90 * time.Minute,
},
{
name: "invalid day format",
input: "abcd",
wantErr: true,
},
{
name: "negative days",
input: "-1d",
wantErr: true,
},
{
name: "zero days",
input: "0d",
wantErr: true,
},
{
name: "non-numeric days",
input: "xyzd",
wantErr: true,
},
{
name: "negative Go duration",
input: "-24h",
wantErr: true,
},
{
name: "zero Go duration",
input: "0s",
wantErr: true,
},
{
name: "invalid Go duration",
input: "notaduration",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseDuration(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -123,11 +123,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
s.proxyAuthClose = proxyAuthClose
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
}
if s.Config.HttpConfig.LetsEncryptDomain != "" {

View File

@@ -58,6 +58,8 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
proxyAuthClose func()
listener net.Listener
certManager *autocert.Manager
update *version.Update
@@ -215,6 +217,10 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
}
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {

View File

@@ -269,19 +269,27 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
log.WithFields(log.Fields{
fields := log.Fields{
"reverse_proxy_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"path": accessLog.GetPath(),
"method": accessLog.GetMethod(),
"response_code": accessLog.GetResponseCode(),
"duration_ms": accessLog.GetDurationMs(),
"source_ip": accessLog.GetSourceIp(),
"auth_mechanism": accessLog.GetAuthMechanism(),
"user_id": accessLog.GetUserId(),
"auth_success": accessLog.GetAuthSuccess(),
}).Debug("Access log from proxy")
}
if mechanism := accessLog.GetAuthMechanism(); mechanism != "" {
fields["auth_mechanism"] = mechanism
}
if userID := accessLog.GetUserId(); userID != "" {
fields["user_id"] = userID
}
if !accessLog.GetAuthSuccess() {
fields["auth_success"] = false
}
log.WithFields(fields).Debugf("%s %s %d (%dms)",
accessLog.GetMethod(),
accessLog.GetPath(),
accessLog.GetResponseCode(),
accessLog.GetDurationMs(),
)
logEntry := &accesslogs.AccessLogEntry{}
logEntry.FromProto(accessLog)

View File

@@ -0,0 +1,234 @@
package grpc
import (
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
lastUsedUpdateInterval = time.Minute
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
lastUsedCleanupInterval = 2 * time.Minute
)
type proxyTokenContextKey struct{}
// ProxyTokenContextKey is the typed key used to store validated token info in context.
var ProxyTokenContextKey = proxyTokenContextKey{}
// proxyTokenID identifies a proxy access token by its database ID.
type proxyTokenID = string
// proxyTokenStore defines the store interface needed for token validation
type proxyTokenStore interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
}
// proxyAuthInterceptor holds state for proxy authentication interceptors.
type proxyAuthInterceptor struct {
store proxyTokenStore
failureLimiter *authFailureLimiter
// lastUsedMu protects lastUsedTimes
lastUsedMu sync.Mutex
lastUsedTimes map[proxyTokenID]time.Time
cancel context.CancelFunc
}
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
ctx, cancel := context.WithCancel(context.Background())
i := &proxyAuthInterceptor{
store: tokenStore,
failureLimiter: newAuthFailureLimiter(),
lastUsedTimes: make(map[proxyTokenID]time.Time),
cancel: cancel,
}
go i.lastUsedCleanupLoop(ctx)
return i
}
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
// The returned close function must be called on shutdown to stop background goroutines.
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
interceptor := newProxyAuthInterceptor(tokenStore)
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(ctx, req)
}
token, err := interceptor.validateProxyToken(ctx)
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
return nil, err
}
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
return handler(ctx, req)
}
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(srv, ss)
}
token, err := interceptor.validateProxyToken(ss.Context())
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
return err
}
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
wrapped := &wrappedServerStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, wrapped)
}
return unary, stream, interceptor.close
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
}
token, err := i.doValidateProxyToken(ctx)
if err != nil {
if clientIP != "" {
i.failureLimiter.recordFailure(clientIP)
}
return nil, err
}
i.maybeUpdateLastUsed(ctx, token.ID)
return token, nil
}
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
}
authValues := md.Get("authorization")
if len(authValues) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
}
authValue := authValues[0]
if !strings.HasPrefix(authValue, "Bearer ") {
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
}
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
if err := plainToken.Validate(); err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
}
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}
return token, nil
}
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
now := time.Now()
i.lastUsedMu.Lock()
lastUpdate, exists := i.lastUsedTimes[tokenID]
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
i.lastUsedMu.Unlock()
return
}
i.lastUsedTimes[tokenID] = now
i.lastUsedMu.Unlock()
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
}
}
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
ticker := time.NewTicker(lastUsedCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
i.cleanupStaleLastUsed()
case <-ctx.Done():
return
}
}
}
// cleanupStaleLastUsed removes entries older than 2x the update interval.
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
i.lastUsedMu.Lock()
defer i.lastUsedMu.Unlock()
now := time.Now()
staleThreshold := 2 * lastUsedUpdateInterval
for id, lastUpdate := range i.lastUsedTimes {
if now.Sub(lastUpdate) > staleThreshold {
delete(i.lastUsedTimes, id)
}
}
}
func (i *proxyAuthInterceptor) close() {
i.cancel()
i.failureLimiter.stop()
}
// GetProxyTokenFromContext retrieves the validated proxy token from the context
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
if !ok {
return nil
}
return token
}
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

View File

@@ -0,0 +1,134 @@
package grpc
import (
"context"
"net"
"sync"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"golang.org/x/time/rate"
"google.golang.org/grpc/peer"
)
const (
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
proxyAuthFailureBurst = 5
// proxyAuthLimiterCleanup is how often stale limiters are removed.
proxyAuthLimiterCleanup = 5 * time.Minute
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
proxyAuthLimiterTTL = 15 * time.Minute
)
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
// One token every 12 seconds = 5 per minute.
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
// clientIP identifies a client by its IP address for rate limiting purposes.
type clientIP = string
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
type authFailureLimiter struct {
mu sync.Mutex
limiters map[clientIP]*limiterEntry
failureRate rate.Limit
cancel context.CancelFunc
}
func newAuthFailureLimiter() *authFailureLimiter {
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
}
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
ctx, cancel := context.WithCancel(context.Background())
l := &authFailureLimiter{
limiters: make(map[clientIP]*limiterEntry),
failureRate: failureRate,
cancel: cancel,
}
go l.cleanupLoop(ctx)
return l
}
// isLimited returns true if the given IP has exhausted its failure budget.
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
l.mu.Lock()
defer l.mu.Unlock()
entry, exists := l.limiters[ip]
if !exists {
return false
}
return entry.limiter.Tokens() < 1
}
// recordFailure consumes a token from the rate limiter for the given IP.
func (l *authFailureLimiter) recordFailure(ip clientIP) {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
entry, exists := l.limiters[ip]
if !exists {
entry = &limiterEntry{
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
}
l.limiters[ip] = entry
}
entry.lastAccess = now
entry.limiter.Allow()
}
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(proxyAuthLimiterCleanup)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.cleanup()
case <-ctx.Done():
return
}
}
}
func (l *authFailureLimiter) cleanup() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, entry := range l.limiters {
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
delete(l.limiters, ip)
}
}
}
func (l *authFailureLimiter) stop() {
l.cancel()
}
// peerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return p.Addr.String()
}
return host
}
return ""
}

View File

@@ -0,0 +1,98 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -126,7 +126,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{},
&types.Group{}, &types.GroupPeer{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -4309,6 +4310,79 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
return nil
}
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error)
}
return &token, nil
}
// GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error)
}
return tokens, nil
}
// SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
}
return nil
}
// RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy access token not found")
}
return nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC())
if result.Error != nil {
return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy access token not found")
}
return nil
}
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -109,6 +109,12 @@ type Store interface {
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, userID, patID string) error
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)

View File

@@ -0,0 +1,137 @@
package types
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/management/server/util"
)
const (
// ProxyTokenPrefix is the globally used prefix for proxy access tokens
ProxyTokenPrefix = "nbx_"
// ProxyTokenSecretLength is the number of characters used for the secret
ProxyTokenSecretLength = 30
// ProxyTokenChecksumLength is the number of characters used for the encoded checksum
ProxyTokenChecksumLength = 6
// ProxyTokenLength is the total number of characters used for the token
ProxyTokenLength = 40
)
// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded.
type HashedProxyToken string
// PlainProxyToken is the raw token string displayed once at creation time.
type PlainProxyToken string
// ProxyAccessToken holds information about a proxy access token including a hashed version for verification
type ProxyAccessToken struct {
ID string `gorm:"primaryKey"`
Name string
HashedToken HashedProxyToken `gorm:"uniqueIndex"`
// AccountID is nil for management-wide tokens, set for account-scoped tokens
AccountID *string `gorm:"index"`
ExpiresAt *time.Time
CreatedBy string
CreatedAt time.Time
LastUsed *time.Time
Revoked bool
}
// IsExpired returns true if the token has expired
func (t *ProxyAccessToken) IsExpired() bool {
if t.ExpiresAt == nil {
return false
}
return time.Now().After(*t.ExpiresAt)
}
// IsValid returns true if the token is not revoked and not expired
func (t *ProxyAccessToken) IsValid() bool {
return !t.Revoked && !t.IsExpired()
}
// ProxyAccessTokenGenerated holds the new token and the plain text version
type ProxyAccessTokenGenerated struct {
PlainToken PlainProxyToken
ProxyAccessToken
}
// CreateNewProxyAccessToken generates a new proxy access token.
// Returns the token with hashed value stored and plain token for one-time display.
func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateProxyToken()
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
var expiresAt *time.Time
if expiresIn > 0 {
expiresAt = util.ToPtr(currentTime.Add(expiresIn))
}
return &ProxyAccessTokenGenerated{
ProxyAccessToken: ProxyAccessToken{
ID: xid.New().String(),
Name: name,
HashedToken: hashedToken,
AccountID: accountID,
ExpiresAt: expiresAt,
CreatedBy: createdBy,
CreatedAt: currentTime,
Revoked: false,
},
PlainToken: plainToken,
}, nil
}
func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) {
secret, err := b.Random(ProxyTokenSecretLength)
if err != nil {
return "", "", err
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum)
return plainToken.Hash(), plainToken, nil
}
// Hash returns the SHA-256 hash of the plain token, base64-encoded.
func (t PlainProxyToken) Hash() HashedProxyToken {
h := sha256.Sum256([]byte(t))
return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:]))
}
// Validate checks the format of a proxy token without checking the database.
func (t PlainProxyToken) Validate() error {
if !strings.HasPrefix(string(t), ProxyTokenPrefix) {
return fmt.Errorf("invalid token prefix")
}
if len(t) != ProxyTokenLength {
return fmt.Errorf("invalid token length")
}
secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength]
checksumStr := t[len(t)-ProxyTokenChecksumLength:]
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum))
if string(checksumStr) != expectedChecksumStr {
return fmt.Errorf("invalid token checksum")
}
return nil
}

View File

@@ -0,0 +1,155 @@
package types
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPlainProxyToken_Validate(t *testing.T) {
tests := []struct {
name string
token PlainProxyToken
wantErr bool
errMsg string
}{
{
name: "valid token",
token: "", // will be generated
wantErr: false,
},
{
name: "wrong prefix",
token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "too short",
token: "nbx_short",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "too long",
token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "correct length but invalid checksum",
token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum
wantErr: true,
errMsg: "invalid token checksum",
},
{
name: "empty token",
token: "",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "only prefix",
token: "nbx_",
wantErr: true,
errMsg: "invalid token length",
},
}
// Generate a valid token for the first test
generated, err := CreateNewProxyAccessToken("test", 0, nil, "test")
require.NoError(t, err)
tests[0].token = generated.PlainToken
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.token.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestPlainProxyToken_Hash(t *testing.T) {
token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X")
hash1 := token1.Hash()
hash2 := token2.Hash()
hash3 := token3.Hash()
assert.Equal(t, hash1, hash2, "same token should produce same hash")
assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes")
assert.NotEmpty(t, hash1)
}
func TestCreateNewProxyAccessToken(t *testing.T) {
t.Run("creates valid token", func(t *testing.T) {
generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user")
require.NoError(t, err)
assert.NotEmpty(t, generated.ID)
assert.Equal(t, "test-token", generated.Name)
assert.Equal(t, "test-user", generated.CreatedBy)
assert.NotEmpty(t, generated.HashedToken)
assert.NotEmpty(t, generated.PlainToken)
assert.Nil(t, generated.ExpiresAt)
assert.False(t, generated.Revoked)
assert.NoError(t, generated.PlainToken.Validate())
assert.Equal(t, ProxyTokenLength, len(generated.PlainToken))
assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)]))
})
t.Run("tokens are unique", func(t *testing.T) {
gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user")
require.NoError(t, err)
gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user")
require.NoError(t, err)
assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken)
assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken)
assert.NotEqual(t, gen1.ID, gen2.ID)
})
}
func TestProxyAccessToken_IsExpired(t *testing.T) {
past := time.Now().Add(-1 * time.Hour)
future := time.Now().Add(1 * time.Hour)
t.Run("expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &past}
assert.True(t, token.IsExpired())
})
t.Run("not expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &future}
assert.False(t, token.IsExpired())
})
t.Run("no expiration", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: nil}
assert.False(t, token.IsExpired())
})
}
func TestProxyAccessToken_IsValid(t *testing.T) {
token := &ProxyAccessToken{
Revoked: false,
}
assert.True(t, token.IsValid())
token.Revoked = true
assert.False(t, token.IsValid())
}

View File

@@ -2,6 +2,7 @@ package cmd
import (
"context"
"fmt"
"os"
"strconv"
"strings"
@@ -16,6 +17,9 @@ import (
const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token.
const envProxyToken = "NB_PROXY_TOKEN"
var (
Version = "dev"
Commit = "unknown"
@@ -42,11 +46,12 @@ var (
)
var rootCmd = &cobra.Command{
Use: "proxy",
Short: "NetBird reverse proxy server",
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
Version: Version,
RunE: runServer,
Use: "proxy",
Short: "NetBird reverse proxy server",
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
Version: Version,
SilenceUsage: true,
RunE: runServer,
}
func init() {
@@ -85,6 +90,11 @@ func SetVersionInfo(version, commit, buildDate, goVersion string) {
}
func runServer(cmd *cobra.Command, args []string) error {
proxyToken := os.Getenv(envProxyToken)
if proxyToken == "" {
return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken)
}
level := "error"
if debugLogs {
level = "debug"
@@ -100,6 +110,7 @@ func runServer(cmd *cobra.Command, args []string) error {
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: proxyURL,
ProxyToken: proxyToken,
CertificateDirectory: certDir,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,

View File

@@ -49,6 +49,13 @@ spec:
value: "https://proxy.local"
- name: NB_PROXY_CERTIFICATE_DIRECTORY
value: "/certs"
- name: NB_PROXY_TOKEN
valueFrom:
secretKeyRef:
name: netbird-proxy-token
key: token
- name: NB_PROXY_ALLOW_INSECURE
value: "true" # Required for HTTP management connection in dev
volumeMounts:
- name: tls-certs
mountPath: /certs

View File

@@ -0,0 +1,48 @@
// Package grpc provides gRPC utilities for the proxy client.
package grpc
import (
"context"
"os"
"strconv"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// EnvProxyAllowInsecure controls whether the proxy token can be sent over non-TLS connections.
const EnvProxyAllowInsecure = "NB_PROXY_ALLOW_INSECURE"
var _ credentials.PerRPCCredentials = (*proxyAuthToken)(nil)
type proxyAuthToken struct {
token string
allowInsecure bool
}
func (t proxyAuthToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + t.token,
}, nil
}
// RequireTransportSecurity returns true by default to protect the token in transit.
// Set NB_PROXY_ALLOW_INSECURE=true to allow non-TLS connections (not recommended for production).
func (t proxyAuthToken) RequireTransportSecurity() bool {
return !t.allowInsecure
}
// WithProxyToken returns a DialOption that sets the proxy access token on each outbound RPC.
func WithProxyToken(token string) grpc.DialOption {
allowInsecure := false
if val := os.Getenv(EnvProxyAllowInsecure); val != "" {
parsed, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("invalid value for %s: %v", EnvProxyAllowInsecure, err)
} else {
allowInsecure = parsed
}
}
return grpc.WithPerRPCCredentials(proxyAuthToken{token: token, allowInsecure: allowInsecure})
}

View File

@@ -32,6 +32,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
@@ -76,6 +77,8 @@ type Server struct {
DebugEndpointAddress string
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
HealthAddress string
// ProxyToken is the access token for authenticating with the management server.
ProxyToken string
}
// NotifyStatus sends a status update to management about tunnel connectivity
@@ -153,6 +156,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
Timeout: 10 * time.Second,
PermitWithoutStream: true,
}),
proxygrpc.WithProxyToken(s.ProxyToken),
)
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)