mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Add proxy <-> management authentication
This commit is contained in:
@@ -80,4 +80,10 @@ func init() {
|
||||
migrationCmd.AddCommand(upCmd)
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
tokenCmd.AddCommand(tokenCreateCmd)
|
||||
tokenCmd.AddCommand(tokenListCmd)
|
||||
tokenCmd.AddCommand(tokenRevokeCmd)
|
||||
rootCmd.AddCommand(tokenCmd)
|
||||
}
|
||||
|
||||
208
management/cmd/token.go
Normal file
208
management/cmd/token.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
tokenName string
|
||||
tokenExpireIn string
|
||||
tokenDatadir string
|
||||
|
||||
tokenCmd = &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Manage proxy access tokens",
|
||||
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
|
||||
}
|
||||
|
||||
tokenCreateCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new proxy access token",
|
||||
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
|
||||
RunE: tokenCreateRun,
|
||||
}
|
||||
|
||||
tokenListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all proxy access tokens",
|
||||
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
|
||||
RunE: tokenListRun,
|
||||
}
|
||||
|
||||
tokenRevokeCmd = &cobra.Command{
|
||||
Use: "revoke [token-id]",
|
||||
Short: "Revoke a proxy access token",
|
||||
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: tokenRevokeRun,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
|
||||
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
|
||||
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
|
||||
tokenCreateCmd.MarkFlagRequired("name") //nolint
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
|
||||
|
||||
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
|
||||
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
expiresIn, err := parseDuration(tokenExpireIn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
|
||||
return fmt.Errorf("save token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Token created successfully!")
|
||||
fmt.Printf("Token: %s\n", generated.PlainToken)
|
||||
fmt.Println()
|
||||
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.")
|
||||
fmt.Printf("Token ID: %s\n", generated.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenListRun(cmd *cobra.Command, _ []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
fmt.Println("No proxy access tokens found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
|
||||
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
|
||||
|
||||
for _, t := range tokens {
|
||||
expires := "never"
|
||||
if t.ExpiresAt != nil {
|
||||
expires = t.ExpiresAt.Format("2006-01-02")
|
||||
}
|
||||
|
||||
lastUsed := "never"
|
||||
if t.LastUsed != nil {
|
||||
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
|
||||
}
|
||||
|
||||
revoked := "no"
|
||||
if t.Revoked {
|
||||
revoked = "yes"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
t.ID,
|
||||
t.Name,
|
||||
t.CreatedAt.Format("2006-01-02"),
|
||||
expires,
|
||||
lastUsed,
|
||||
revoked,
|
||||
)
|
||||
}
|
||||
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
|
||||
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
|
||||
tokenID := args[0]
|
||||
|
||||
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
|
||||
return fmt.Errorf("revoke token: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Token %s revoked successfully.\n", tokenID)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
|
||||
// An empty string returns zero duration (no expiration).
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if s[len(s)-1] == 'd' {
|
||||
d, err := strconv.Atoi(s[:len(s)-1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid day format: %s", s)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return time.Duration(d) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %s", s)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
101
management/cmd/token_test.go
Normal file
101
management/cmd/token_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty string returns zero",
|
||||
input: "",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "days suffix",
|
||||
input: "30d",
|
||||
expected: 30 * 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "one day",
|
||||
input: "1d",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "365 days",
|
||||
input: "365d",
|
||||
expected: 365 * 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "hours via Go duration",
|
||||
input: "24h",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "minutes via Go duration",
|
||||
input: "30m",
|
||||
expected: 30 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "complex Go duration",
|
||||
input: "1h30m",
|
||||
expected: 90 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "invalid day format",
|
||||
input: "abcd",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative days",
|
||||
input: "-1d",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero days",
|
||||
input: "0d",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-numeric days",
|
||||
input: "xyzd",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative Go duration",
|
||||
input: "-24h",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero Go duration",
|
||||
input: "0s",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid Go duration",
|
||||
input: "notaduration",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,11 +123,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||
}
|
||||
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||
s.proxyAuthClose = proxyAuthClose
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||
}
|
||||
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
|
||||
@@ -58,6 +58,8 @@ type BaseServer struct {
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
|
||||
proxyAuthClose func()
|
||||
|
||||
listener net.Listener
|
||||
certManager *autocert.Manager
|
||||
update *version.Update
|
||||
@@ -215,6 +217,10 @@ func (s *BaseServer) Stop() error {
|
||||
_ = s.certManager.Listener().Close()
|
||||
}
|
||||
s.GRPCServer().Stop()
|
||||
if s.proxyAuthClose != nil {
|
||||
s.proxyAuthClose()
|
||||
s.proxyAuthClose = nil
|
||||
}
|
||||
_ = s.Store().Close(ctx)
|
||||
_ = s.EventStore().Close(ctx)
|
||||
if s.update != nil {
|
||||
|
||||
@@ -269,19 +269,27 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"reverse_proxy_id": accessLog.GetServiceId(),
|
||||
"account_id": accessLog.GetAccountId(),
|
||||
"host": accessLog.GetHost(),
|
||||
"path": accessLog.GetPath(),
|
||||
"method": accessLog.GetMethod(),
|
||||
"response_code": accessLog.GetResponseCode(),
|
||||
"duration_ms": accessLog.GetDurationMs(),
|
||||
"source_ip": accessLog.GetSourceIp(),
|
||||
"auth_mechanism": accessLog.GetAuthMechanism(),
|
||||
"user_id": accessLog.GetUserId(),
|
||||
"auth_success": accessLog.GetAuthSuccess(),
|
||||
}).Debug("Access log from proxy")
|
||||
}
|
||||
if mechanism := accessLog.GetAuthMechanism(); mechanism != "" {
|
||||
fields["auth_mechanism"] = mechanism
|
||||
}
|
||||
if userID := accessLog.GetUserId(); userID != "" {
|
||||
fields["user_id"] = userID
|
||||
}
|
||||
if !accessLog.GetAuthSuccess() {
|
||||
fields["auth_success"] = false
|
||||
}
|
||||
log.WithFields(fields).Debugf("%s %s %d (%dms)",
|
||||
accessLog.GetMethod(),
|
||||
accessLog.GetPath(),
|
||||
accessLog.GetResponseCode(),
|
||||
accessLog.GetDurationMs(),
|
||||
)
|
||||
|
||||
logEntry := &accesslogs.AccessLogEntry{}
|
||||
logEntry.FromProto(accessLog)
|
||||
|
||||
234
management/internals/shared/grpc/proxy_auth.go
Normal file
234
management/internals/shared/grpc/proxy_auth.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
|
||||
lastUsedUpdateInterval = time.Minute
|
||||
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
|
||||
lastUsedCleanupInterval = 2 * time.Minute
|
||||
)
|
||||
|
||||
type proxyTokenContextKey struct{}
|
||||
|
||||
// ProxyTokenContextKey is the typed key used to store validated token info in context.
|
||||
var ProxyTokenContextKey = proxyTokenContextKey{}
|
||||
|
||||
// proxyTokenID identifies a proxy access token by its database ID.
|
||||
type proxyTokenID = string
|
||||
|
||||
// proxyTokenStore defines the store interface needed for token validation
|
||||
type proxyTokenStore interface {
|
||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||
}
|
||||
|
||||
// proxyAuthInterceptor holds state for proxy authentication interceptors.
|
||||
type proxyAuthInterceptor struct {
|
||||
store proxyTokenStore
|
||||
failureLimiter *authFailureLimiter
|
||||
|
||||
// lastUsedMu protects lastUsedTimes
|
||||
lastUsedMu sync.Mutex
|
||||
lastUsedTimes map[proxyTokenID]time.Time
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
i := &proxyAuthInterceptor{
|
||||
store: tokenStore,
|
||||
failureLimiter: newAuthFailureLimiter(),
|
||||
lastUsedTimes: make(map[proxyTokenID]time.Time),
|
||||
cancel: cancel,
|
||||
}
|
||||
go i.lastUsedCleanupLoop(ctx)
|
||||
return i
|
||||
}
|
||||
|
||||
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
|
||||
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
|
||||
// The returned close function must be called on shutdown to stop background goroutines.
|
||||
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
|
||||
interceptor := newProxyAuthInterceptor(tokenStore)
|
||||
|
||||
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
token, err := interceptor.validateProxyToken(ctx)
|
||||
if err != nil {
|
||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
token, err := interceptor.validateProxyToken(ss.Context())
|
||||
if err != nil {
|
||||
// Log auth failures explicitly; gRPC doesn't log these by default.
|
||||
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
|
||||
wrapped := &wrappedServerStream{
|
||||
ServerStream: ss,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
return handler(srv, wrapped)
|
||||
}
|
||||
|
||||
return unary, stream, interceptor.close
|
||||
}
|
||||
|
||||
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||
clientIP := peerIPFromContext(ctx)
|
||||
|
||||
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||
}
|
||||
|
||||
token, err := i.doValidateProxyToken(ctx)
|
||||
if err != nil {
|
||||
if clientIP != "" {
|
||||
i.failureLimiter.recordFailure(clientIP)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.maybeUpdateLastUsed(ctx, token.ID)
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
authValues := md.Get("authorization")
|
||||
if len(authValues) == 0 {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
|
||||
}
|
||||
|
||||
authValue := authValues[0]
|
||||
if !strings.HasPrefix(authValue, "Bearer ") {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
||||
}
|
||||
|
||||
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
|
||||
|
||||
if err := plainToken.Validate(); err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
|
||||
}
|
||||
|
||||
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||
|
||||
if !token.IsValid() {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
|
||||
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
|
||||
now := time.Now()
|
||||
|
||||
i.lastUsedMu.Lock()
|
||||
lastUpdate, exists := i.lastUsedTimes[tokenID]
|
||||
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
|
||||
i.lastUsedMu.Unlock()
|
||||
return
|
||||
}
|
||||
i.lastUsedTimes[tokenID] = now
|
||||
i.lastUsedMu.Unlock()
|
||||
|
||||
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(lastUsedCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
i.cleanupStaleLastUsed()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleLastUsed removes entries older than 2x the update interval.
|
||||
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
|
||||
i.lastUsedMu.Lock()
|
||||
defer i.lastUsedMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
staleThreshold := 2 * lastUsedUpdateInterval
|
||||
for id, lastUpdate := range i.lastUsedTimes {
|
||||
if now.Sub(lastUpdate) > staleThreshold {
|
||||
delete(i.lastUsedTimes, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *proxyAuthInterceptor) close() {
|
||||
i.cancel()
|
||||
i.failureLimiter.stop()
|
||||
}
|
||||
|
||||
// GetProxyTokenFromContext retrieves the validated proxy token from the context
|
||||
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
|
||||
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
|
||||
type wrappedServerStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (w *wrappedServerStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
134
management/internals/shared/grpc/proxy_auth_ratelimit.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"golang.org/x/time/rate"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
|
||||
proxyAuthFailureBurst = 5
|
||||
// proxyAuthLimiterCleanup is how often stale limiters are removed.
|
||||
proxyAuthLimiterCleanup = 5 * time.Minute
|
||||
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
|
||||
proxyAuthLimiterTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
|
||||
// One token every 12 seconds = 5 per minute.
|
||||
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
|
||||
|
||||
// clientIP identifies a client by its IP address for rate limiting purposes.
|
||||
type clientIP = string
|
||||
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
|
||||
type authFailureLimiter struct {
|
||||
mu sync.Mutex
|
||||
limiters map[clientIP]*limiterEntry
|
||||
failureRate rate.Limit
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newAuthFailureLimiter() *authFailureLimiter {
|
||||
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
|
||||
}
|
||||
|
||||
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l := &authFailureLimiter{
|
||||
limiters: make(map[clientIP]*limiterEntry),
|
||||
failureRate: failureRate,
|
||||
cancel: cancel,
|
||||
}
|
||||
go l.cleanupLoop(ctx)
|
||||
return l
|
||||
}
|
||||
|
||||
// isLimited returns true if the given IP has exhausted its failure budget.
|
||||
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
entry, exists := l.limiters[ip]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return entry.limiter.Tokens() < 1
|
||||
}
|
||||
|
||||
// recordFailure consumes a token from the rate limiter for the given IP.
|
||||
func (l *authFailureLimiter) recordFailure(ip clientIP) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, exists := l.limiters[ip]
|
||||
if !exists {
|
||||
entry = &limiterEntry{
|
||||
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
|
||||
}
|
||||
l.limiters[ip] = entry
|
||||
}
|
||||
entry.lastAccess = now
|
||||
entry.limiter.Allow()
|
||||
}
|
||||
|
||||
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(proxyAuthLimiterCleanup)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
l.cleanup()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authFailureLimiter) cleanup() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for ip, entry := range l.limiters {
|
||||
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
|
||||
delete(l.limiters, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authFailureLimiter) stop() {
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
// peerIPFromContext extracts the client IP from the gRPC context.
|
||||
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||
func peerIPFromContext(ctx context.Context) clientIP {
|
||||
if addr, ok := realip.FromContext(ctx); ok {
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
if p, ok := peer.FromContext(ctx); ok {
|
||||
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||
if err != nil {
|
||||
return p.Addr.String()
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
ip := "192.168.1.1"
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure(ip)
|
||||
}
|
||||
|
||||
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure("192.168.1.1")
|
||||
}
|
||||
|
||||
assert.True(t, l.isLimited("192.168.1.1"))
|
||||
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
||||
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
||||
defer l.stop()
|
||||
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// Exhaust burst
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure(ip)
|
||||
}
|
||||
require.True(t, l.isLimited(ip))
|
||||
|
||||
// Wait for token replenishment
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
l.recordFailure("10.0.0.1")
|
||||
|
||||
l.mu.Lock()
|
||||
require.Len(t, l.limiters, 1)
|
||||
// Backdate the entry so it looks stale
|
||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||
l.mu.Unlock()
|
||||
|
||||
l.cleanup()
|
||||
|
||||
l.mu.Lock()
|
||||
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
l.recordFailure("10.0.0.1")
|
||||
l.recordFailure("10.0.0.2")
|
||||
|
||||
l.mu.Lock()
|
||||
// Only backdate one entry
|
||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||
l.mu.Unlock()
|
||||
|
||||
l.cleanup()
|
||||
|
||||
l.mu.Lock()
|
||||
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
||||
assert.Contains(t, l.limiters, "10.0.0.2")
|
||||
l.mu.Unlock()
|
||||
}
|
||||
@@ -126,7 +126,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{},
|
||||
&types.Group{}, &types.GroupPeer{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
@@ -4309,6 +4310,79 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
|
||||
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var token types.ProxyAccessToken
|
||||
result := tx.Take(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetAllProxyAccessTokens retrieves all proxy access tokens.
|
||||
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var tokens []*types.ProxyAccessToken
|
||||
result := tx.Find(&tokens)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken saves a proxy access token to the database.
|
||||
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
|
||||
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
|
||||
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeProxyAccessToken revokes a proxy access token by its ID.
|
||||
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||
Where(idQueryCondition, tokenID).
|
||||
Update("last_used", time.Now().UTC())
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -109,6 +109,12 @@ type Store interface {
|
||||
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, userID, patID string) error
|
||||
|
||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
|
||||
137
management/server/types/proxy_access_token.go
Normal file
137
management/server/types/proxy_access_token.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// ProxyTokenPrefix is the globally used prefix for proxy access tokens
|
||||
ProxyTokenPrefix = "nbx_"
|
||||
// ProxyTokenSecretLength is the number of characters used for the secret
|
||||
ProxyTokenSecretLength = 30
|
||||
// ProxyTokenChecksumLength is the number of characters used for the encoded checksum
|
||||
ProxyTokenChecksumLength = 6
|
||||
// ProxyTokenLength is the total number of characters used for the token
|
||||
ProxyTokenLength = 40
|
||||
)
|
||||
|
||||
// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded.
|
||||
type HashedProxyToken string
|
||||
|
||||
// PlainProxyToken is the raw token string displayed once at creation time.
|
||||
type PlainProxyToken string
|
||||
|
||||
// ProxyAccessToken holds information about a proxy access token including a hashed version for verification
|
||||
type ProxyAccessToken struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Name string
|
||||
HashedToken HashedProxyToken `gorm:"uniqueIndex"`
|
||||
// AccountID is nil for management-wide tokens, set for account-scoped tokens
|
||||
AccountID *string `gorm:"index"`
|
||||
ExpiresAt *time.Time
|
||||
CreatedBy string
|
||||
CreatedAt time.Time
|
||||
LastUsed *time.Time
|
||||
Revoked bool
|
||||
}
|
||||
|
||||
// IsExpired returns true if the token has expired
|
||||
func (t *ProxyAccessToken) IsExpired() bool {
|
||||
if t.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*t.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsValid returns true if the token is not revoked and not expired
|
||||
func (t *ProxyAccessToken) IsValid() bool {
|
||||
return !t.Revoked && !t.IsExpired()
|
||||
}
|
||||
|
||||
// ProxyAccessTokenGenerated holds the new token and the plain text version
|
||||
type ProxyAccessTokenGenerated struct {
|
||||
PlainToken PlainProxyToken
|
||||
ProxyAccessToken
|
||||
}
|
||||
|
||||
// CreateNewProxyAccessToken generates a new proxy access token.
|
||||
// Returns the token with hashed value stored and plain token for one-time display.
|
||||
func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateProxyToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentTime := time.Now().UTC()
|
||||
var expiresAt *time.Time
|
||||
if expiresIn > 0 {
|
||||
expiresAt = util.ToPtr(currentTime.Add(expiresIn))
|
||||
}
|
||||
|
||||
return &ProxyAccessTokenGenerated{
|
||||
ProxyAccessToken: ProxyAccessToken{
|
||||
ID: xid.New().String(),
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: currentTime,
|
||||
Revoked: false,
|
||||
},
|
||||
PlainToken: plainToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) {
|
||||
secret, err := b.Random(ProxyTokenSecretLength)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
checksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
encodedChecksum := base62.Encode(checksum)
|
||||
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
|
||||
plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum)
|
||||
return plainToken.Hash(), plainToken, nil
|
||||
}
|
||||
|
||||
// Hash returns the SHA-256 hash of the plain token, base64-encoded.
|
||||
func (t PlainProxyToken) Hash() HashedProxyToken {
|
||||
h := sha256.Sum256([]byte(t))
|
||||
return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:]))
|
||||
}
|
||||
|
||||
// Validate checks the format of a proxy token without checking the database.
|
||||
func (t PlainProxyToken) Validate() error {
|
||||
if !strings.HasPrefix(string(t), ProxyTokenPrefix) {
|
||||
return fmt.Errorf("invalid token prefix")
|
||||
}
|
||||
|
||||
if len(t) != ProxyTokenLength {
|
||||
return fmt.Errorf("invalid token length")
|
||||
}
|
||||
|
||||
secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength]
|
||||
checksumStr := t[len(t)-ProxyTokenChecksumLength:]
|
||||
|
||||
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum))
|
||||
|
||||
if string(checksumStr) != expectedChecksumStr {
|
||||
return fmt.Errorf("invalid token checksum")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
155
management/server/types/proxy_access_token_test.go
Normal file
155
management/server/types/proxy_access_token_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPlainProxyToken_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token PlainProxyToken
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
token: "", // will be generated
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong prefix",
|
||||
token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token prefix",
|
||||
},
|
||||
{
|
||||
name: "too short",
|
||||
token: "nbx_short",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
{
|
||||
name: "correct length but invalid checksum",
|
||||
token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum
|
||||
wantErr: true,
|
||||
errMsg: "invalid token checksum",
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token prefix",
|
||||
},
|
||||
{
|
||||
name: "only prefix",
|
||||
token: "nbx_",
|
||||
wantErr: true,
|
||||
errMsg: "invalid token length",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate a valid token for the first test
|
||||
generated, err := CreateNewProxyAccessToken("test", 0, nil, "test")
|
||||
require.NoError(t, err)
|
||||
tests[0].token = generated.PlainToken
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.token.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainProxyToken_Hash(t *testing.T) {
|
||||
token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
|
||||
token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
|
||||
token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X")
|
||||
|
||||
hash1 := token1.Hash()
|
||||
hash2 := token2.Hash()
|
||||
hash3 := token3.Hash()
|
||||
|
||||
assert.Equal(t, hash1, hash2, "same token should produce same hash")
|
||||
assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes")
|
||||
assert.NotEmpty(t, hash1)
|
||||
}
|
||||
|
||||
func TestCreateNewProxyAccessToken(t *testing.T) {
|
||||
t.Run("creates valid token", func(t *testing.T) {
|
||||
generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, generated.ID)
|
||||
assert.Equal(t, "test-token", generated.Name)
|
||||
assert.Equal(t, "test-user", generated.CreatedBy)
|
||||
assert.NotEmpty(t, generated.HashedToken)
|
||||
assert.NotEmpty(t, generated.PlainToken)
|
||||
assert.Nil(t, generated.ExpiresAt)
|
||||
assert.False(t, generated.Revoked)
|
||||
|
||||
assert.NoError(t, generated.PlainToken.Validate())
|
||||
assert.Equal(t, ProxyTokenLength, len(generated.PlainToken))
|
||||
assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)]))
|
||||
})
|
||||
|
||||
t.Run("tokens are unique", func(t *testing.T) {
|
||||
gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken)
|
||||
assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken)
|
||||
assert.NotEqual(t, gen1.ID, gen2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyAccessToken_IsExpired(t *testing.T) {
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
future := time.Now().Add(1 * time.Hour)
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: &past}
|
||||
assert.True(t, token.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("not expired token", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: &future}
|
||||
assert.False(t, token.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("no expiration", func(t *testing.T) {
|
||||
token := &ProxyAccessToken{ExpiresAt: nil}
|
||||
assert.False(t, token.IsExpired())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyAccessToken_IsValid(t *testing.T) {
|
||||
token := &ProxyAccessToken{
|
||||
Revoked: false,
|
||||
}
|
||||
|
||||
assert.True(t, token.IsValid())
|
||||
|
||||
token.Revoked = true
|
||||
assert.False(t, token.IsValid())
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -16,6 +17,9 @@ import (
|
||||
|
||||
const DefaultManagementURL = "https://api.netbird.io:443"
|
||||
|
||||
// envProxyToken is the environment variable name for the proxy access token.
|
||||
const envProxyToken = "NB_PROXY_TOKEN"
|
||||
|
||||
var (
|
||||
Version = "dev"
|
||||
Commit = "unknown"
|
||||
@@ -42,11 +46,12 @@ var (
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "proxy",
|
||||
Short: "NetBird reverse proxy server",
|
||||
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
|
||||
Version: Version,
|
||||
RunE: runServer,
|
||||
Use: "proxy",
|
||||
Short: "NetBird reverse proxy server",
|
||||
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
|
||||
Version: Version,
|
||||
SilenceUsage: true,
|
||||
RunE: runServer,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -85,6 +90,11 @@ func SetVersionInfo(version, commit, buildDate, goVersion string) {
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
proxyToken := os.Getenv(envProxyToken)
|
||||
if proxyToken == "" {
|
||||
return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken)
|
||||
}
|
||||
|
||||
level := "error"
|
||||
if debugLogs {
|
||||
level = "debug"
|
||||
@@ -100,6 +110,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Version: Version,
|
||||
ManagementAddress: mgmtAddr,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyToken: proxyToken,
|
||||
CertificateDirectory: certDir,
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
|
||||
@@ -49,6 +49,13 @@ spec:
|
||||
value: "https://proxy.local"
|
||||
- name: NB_PROXY_CERTIFICATE_DIRECTORY
|
||||
value: "/certs"
|
||||
- name: NB_PROXY_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: netbird-proxy-token
|
||||
key: token
|
||||
- name: NB_PROXY_ALLOW_INSECURE
|
||||
value: "true" # Required for HTTP management connection in dev
|
||||
volumeMounts:
|
||||
- name: tls-certs
|
||||
mountPath: /certs
|
||||
|
||||
48
proxy/internal/grpc/auth.go
Normal file
48
proxy/internal/grpc/auth.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Package grpc provides gRPC utilities for the proxy client.
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// EnvProxyAllowInsecure controls whether the proxy token can be sent over non-TLS connections.
|
||||
const EnvProxyAllowInsecure = "NB_PROXY_ALLOW_INSECURE"
|
||||
|
||||
var _ credentials.PerRPCCredentials = (*proxyAuthToken)(nil)
|
||||
|
||||
type proxyAuthToken struct {
|
||||
token string
|
||||
allowInsecure bool
|
||||
}
|
||||
|
||||
func (t proxyAuthToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
"authorization": "Bearer " + t.token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequireTransportSecurity returns true by default to protect the token in transit.
|
||||
// Set NB_PROXY_ALLOW_INSECURE=true to allow non-TLS connections (not recommended for production).
|
||||
func (t proxyAuthToken) RequireTransportSecurity() bool {
|
||||
return !t.allowInsecure
|
||||
}
|
||||
|
||||
// WithProxyToken returns a DialOption that sets the proxy access token on each outbound RPC.
|
||||
func WithProxyToken(token string) grpc.DialOption {
|
||||
allowInsecure := false
|
||||
if val := os.Getenv(EnvProxyAllowInsecure); val != "" {
|
||||
parsed, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid value for %s: %v", EnvProxyAllowInsecure, err)
|
||||
} else {
|
||||
allowInsecure = parsed
|
||||
}
|
||||
}
|
||||
return grpc.WithPerRPCCredentials(proxyAuthToken{token: token, allowInsecure: allowInsecure})
|
||||
}
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
@@ -76,6 +77,8 @@ type Server struct {
|
||||
DebugEndpointAddress string
|
||||
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
|
||||
HealthAddress string
|
||||
// ProxyToken is the access token for authenticating with the management server.
|
||||
ProxyToken string
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||
@@ -153,6 +156,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
proxygrpc.WithProxyToken(s.ProxyToken),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user