From 0a3a9f977daecd81d581e833bb0cec0cb8deb568 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 6 Feb 2026 02:06:24 +0800 Subject: [PATCH] Add proxy <-> management authentication --- management/cmd/root.go | 6 + management/cmd/token.go | 208 ++++++++++++++++ management/cmd/token_test.go | 101 ++++++++ management/internals/server/boot.go | 6 +- management/internals/server/server.go | 6 + management/internals/shared/grpc/proxy.go | 26 +- .../internals/shared/grpc/proxy_auth.go | 234 ++++++++++++++++++ .../shared/grpc/proxy_auth_ratelimit.go | 134 ++++++++++ .../shared/grpc/proxy_auth_ratelimit_test.go | 98 ++++++++ management/server/store/sql_store.go | 76 +++++- management/server/store/store.go | 6 + management/server/types/proxy_access_token.go | 137 ++++++++++ .../server/types/proxy_access_token_test.go | 155 ++++++++++++ proxy/cmd/proxy/cmd/root.go | 21 +- proxy/deploy/k8s/deployment.yaml | 7 + proxy/internal/grpc/auth.go | 48 ++++ proxy/server.go | 4 + 17 files changed, 1256 insertions(+), 17 deletions(-) create mode 100644 management/cmd/token.go create mode 100644 management/cmd/token_test.go create mode 100644 management/internals/shared/grpc/proxy_auth.go create mode 100644 management/internals/shared/grpc/proxy_auth_ratelimit.go create mode 100644 management/internals/shared/grpc/proxy_auth_ratelimit_test.go create mode 100644 management/server/types/proxy_access_token.go create mode 100644 management/server/types/proxy_access_token_test.go create mode 100644 proxy/internal/grpc/auth.go diff --git a/management/cmd/root.go b/management/cmd/root.go index b60f79c23..2eca7859d 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -80,4 +80,10 @@ func init() { migrationCmd.AddCommand(upCmd) rootCmd.AddCommand(migrationCmd) + + tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location") + tokenCmd.AddCommand(tokenCreateCmd) + tokenCmd.AddCommand(tokenListCmd) + tokenCmd.AddCommand(tokenRevokeCmd) + rootCmd.AddCommand(tokenCmd) } diff --git a/management/cmd/token.go b/management/cmd/token.go new file mode 100644 index 000000000..6de193dbb --- /dev/null +++ b/management/cmd/token.go @@ -0,0 +1,208 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "strconv" + "text/tabwriter" + "time" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/formatter/hook" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util" +) + +var ( + tokenName string + tokenExpireIn string + tokenDatadir string + + tokenCmd = &cobra.Command{ + Use: "token", + Short: "Manage proxy access tokens", + Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.", + } + + tokenCreateCmd = &cobra.Command{ + Use: "create", + Short: "Create a new proxy access token", + Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.", + RunE: tokenCreateRun, + } + + tokenListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List all proxy access tokens", + Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.", + RunE: tokenListRun, + } + + tokenRevokeCmd = &cobra.Command{ + Use: "revoke [token-id]", + Short: "Revoke a proxy access token", + Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.", + Args: cobra.ExactArgs(1), + RunE: tokenRevokeRun, + } +) + +func init() { + tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)") + + tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)") + tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration") + tokenCreateCmd.MarkFlagRequired("name") //nolint +} + +// withTokenStore initializes logging, loads config, opens the store, and calls fn. +func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error { + if err := util.InitLog("error", "console"); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) + + config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + datadir := config.Datadir + if tokenDatadir != "" { + datadir = tokenDatadir + } + + s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true) + if err != nil { + return fmt.Errorf("create store: %w", err) + } + defer func() { + if err := s.Close(ctx); err != nil { + log.Debugf("close store: %v", err) + } + }() + + return fn(ctx, s) +} + +func tokenCreateRun(cmd *cobra.Command, _ []string) error { + return withTokenStore(cmd, func(ctx context.Context, s store.Store) error { + expiresIn, err := parseDuration(tokenExpireIn) + if err != nil { + return fmt.Errorf("parse expiration: %w", err) + } + + generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI") + if err != nil { + return fmt.Errorf("generate token: %w", err) + } + + if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil { + return fmt.Errorf("save token: %w", err) + } + + fmt.Println("Token created successfully!") + fmt.Printf("Token: %s\n", generated.PlainToken) + fmt.Println() + fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") + fmt.Printf("Token ID: %s\n", generated.ID) + + return nil + }) +} + +func tokenListRun(cmd *cobra.Command, _ []string) error { + return withTokenStore(cmd, func(ctx context.Context, s store.Store) error { + tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone) + if err != nil { + return fmt.Errorf("list tokens: %w", err) + } + + if len(tokens) == 0 { + fmt.Println("No proxy access tokens found.") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED") + fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------") + + for _, t := range tokens { + expires := "never" + if t.ExpiresAt != nil { + expires = t.ExpiresAt.Format("2006-01-02") + } + + lastUsed := "never" + if t.LastUsed != nil { + lastUsed = t.LastUsed.Format("2006-01-02 15:04") + } + + revoked := "no" + if t.Revoked { + revoked = "yes" + } + + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n", + t.ID, + t.Name, + t.CreatedAt.Format("2006-01-02"), + expires, + lastUsed, + revoked, + ) + } + + w.Flush() + + return nil + }) +} + +func tokenRevokeRun(cmd *cobra.Command, args []string) error { + return withTokenStore(cmd, func(ctx context.Context, s store.Store) error { + tokenID := args[0] + + if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil { + return fmt.Errorf("revoke token: %w", err) + } + + fmt.Printf("Token %s revoked successfully.\n", tokenID) + return nil + }) +} + +// parseDuration parses a duration string with support for days (e.g., "30d", "365d"). +// An empty string returns zero duration (no expiration). +func parseDuration(s string) (time.Duration, error) { + if len(s) == 0 { + return 0, nil + } + + if s[len(s)-1] == 'd' { + d, err := strconv.Atoi(s[:len(s)-1]) + if err != nil { + return 0, fmt.Errorf("invalid day format: %s", s) + } + if d <= 0 { + return 0, fmt.Errorf("duration must be positive: %s", s) + } + return time.Duration(d) * 24 * time.Hour, nil + } + + d, err := time.ParseDuration(s) + if err != nil { + return 0, err + } + if d <= 0 { + return 0, fmt.Errorf("duration must be positive: %s", s) + } + return d, nil +} diff --git a/management/cmd/token_test.go b/management/cmd/token_test.go new file mode 100644 index 000000000..35ac0895e --- /dev/null +++ b/management/cmd/token_test.go @@ -0,0 +1,101 @@ +package cmd + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseDuration(t *testing.T) { + tests := []struct { + name string + input string + expected time.Duration + wantErr bool + }{ + { + name: "empty string returns zero", + input: "", + expected: 0, + }, + { + name: "days suffix", + input: "30d", + expected: 30 * 24 * time.Hour, + }, + { + name: "one day", + input: "1d", + expected: 24 * time.Hour, + }, + { + name: "365 days", + input: "365d", + expected: 365 * 24 * time.Hour, + }, + { + name: "hours via Go duration", + input: "24h", + expected: 24 * time.Hour, + }, + { + name: "minutes via Go duration", + input: "30m", + expected: 30 * time.Minute, + }, + { + name: "complex Go duration", + input: "1h30m", + expected: 90 * time.Minute, + }, + { + name: "invalid day format", + input: "abcd", + wantErr: true, + }, + { + name: "negative days", + input: "-1d", + wantErr: true, + }, + { + name: "zero days", + input: "0d", + wantErr: true, + }, + { + name: "non-numeric days", + input: "xyzd", + wantErr: true, + }, + { + name: "negative Go duration", + input: "-24h", + wantErr: true, + }, + { + name: "zero Go duration", + input: "0s", + wantErr: true, + }, + { + name: "invalid Go duration", + input: "notaduration", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseDuration(tt.input) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 02772f638..5233359a6 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -123,11 +123,13 @@ func (s *BaseServer) GRPCServer() *grpc.Server { realip.WithTrustedProxiesCount(trustedProxiesCount), realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}), } + proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store()) + s.proxyAuthClose = proxyAuthClose gRPCOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), - grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor), - grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), + grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary), + grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream), } if s.Config.HttpConfig.LetsEncryptDomain != "" { diff --git a/management/internals/server/server.go b/management/internals/server/server.go index cd8d8e8fb..03667d419 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -58,6 +58,8 @@ type BaseServer struct { mgmtMetricsPort int mgmtPort int + proxyAuthClose func() + listener net.Listener certManager *autocert.Manager update *version.Update @@ -215,6 +217,10 @@ func (s *BaseServer) Stop() error { _ = s.certManager.Listener().Close() } s.GRPCServer().Stop() + if s.proxyAuthClose != nil { + s.proxyAuthClose() + s.proxyAuthClose = nil + } _ = s.Store().Close(ctx) _ = s.EventStore().Close(ctx) if s.update != nil { diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 8bc149e59..83fe6ebb8 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -269,19 +269,27 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() - log.WithFields(log.Fields{ + fields := log.Fields{ "reverse_proxy_id": accessLog.GetServiceId(), "account_id": accessLog.GetAccountId(), "host": accessLog.GetHost(), - "path": accessLog.GetPath(), - "method": accessLog.GetMethod(), - "response_code": accessLog.GetResponseCode(), - "duration_ms": accessLog.GetDurationMs(), "source_ip": accessLog.GetSourceIp(), - "auth_mechanism": accessLog.GetAuthMechanism(), - "user_id": accessLog.GetUserId(), - "auth_success": accessLog.GetAuthSuccess(), - }).Debug("Access log from proxy") + } + if mechanism := accessLog.GetAuthMechanism(); mechanism != "" { + fields["auth_mechanism"] = mechanism + } + if userID := accessLog.GetUserId(); userID != "" { + fields["user_id"] = userID + } + if !accessLog.GetAuthSuccess() { + fields["auth_success"] = false + } + log.WithFields(fields).Debugf("%s %s %d (%dms)", + accessLog.GetMethod(), + accessLog.GetPath(), + accessLog.GetResponseCode(), + accessLog.GetDurationMs(), + ) logEntry := &accesslogs.AccessLogEntry{} logEntry.FromProto(accessLog) diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go new file mode 100644 index 000000000..6daeab5f2 --- /dev/null +++ b/management/internals/shared/grpc/proxy_auth.go @@ -0,0 +1,234 @@ +package grpc + +import ( + "context" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + // lastUsedUpdateInterval is the minimum interval between last_used updates for the same token. + lastUsedUpdateInterval = time.Minute + // lastUsedCleanupInterval is how often stale lastUsed entries are removed. + lastUsedCleanupInterval = 2 * time.Minute +) + +type proxyTokenContextKey struct{} + +// ProxyTokenContextKey is the typed key used to store validated token info in context. +var ProxyTokenContextKey = proxyTokenContextKey{} + +// proxyTokenID identifies a proxy access token by its database ID. +type proxyTokenID = string + +// proxyTokenStore defines the store interface needed for token validation +type proxyTokenStore interface { + GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) + MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error +} + +// proxyAuthInterceptor holds state for proxy authentication interceptors. +type proxyAuthInterceptor struct { + store proxyTokenStore + failureLimiter *authFailureLimiter + + // lastUsedMu protects lastUsedTimes + lastUsedMu sync.Mutex + lastUsedTimes map[proxyTokenID]time.Time + cancel context.CancelFunc +} + +func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor { + ctx, cancel := context.WithCancel(context.Background()) + i := &proxyAuthInterceptor{ + store: tokenStore, + failureLimiter: newAuthFailureLimiter(), + lastUsedTimes: make(map[proxyTokenID]time.Time), + cancel: cancel, + } + go i.lastUsedCleanupLoop(ctx) + return i +} + +// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens. +// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting. +// The returned close function must be called on shutdown to stop background goroutines. +func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) { + interceptor := newProxyAuthInterceptor(tokenStore) + + unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") { + return handler(ctx, req) + } + + token, err := interceptor.validateProxyToken(ctx) + if err != nil { + // Log auth failures explicitly; gRPC doesn't log these by default. + log.WithContext(ctx).Warnf("proxy auth failed: %v", err) + return nil, err + } + + ctx = context.WithValue(ctx, ProxyTokenContextKey, token) + return handler(ctx, req) + } + + stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") { + return handler(srv, ss) + } + + token, err := interceptor.validateProxyToken(ss.Context()) + if err != nil { + // Log auth failures explicitly; gRPC doesn't log these by default. + log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err) + return err + } + + ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token) + wrapped := &wrappedServerStream{ + ServerStream: ss, + ctx: ctx, + } + + return handler(srv, wrapped) + } + + return unary, stream, interceptor.close +} + +func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) { + clientIP := peerIPFromContext(ctx) + + if clientIP != "" && i.failureLimiter.isLimited(clientIP) { + return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts") + } + + token, err := i.doValidateProxyToken(ctx) + if err != nil { + if clientIP != "" { + i.failureLimiter.recordFailure(clientIP) + } + return nil, err + } + + i.maybeUpdateLastUsed(ctx, token.ID) + + return token, nil +} + +func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "missing metadata") + } + + authValues := md.Get("authorization") + if len(authValues) == 0 { + return nil, status.Errorf(codes.Unauthenticated, "missing authorization header") + } + + authValue := authValues[0] + if !strings.HasPrefix(authValue, "Bearer ") { + return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format") + } + + plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer ")) + + if err := plainToken.Validate(); err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid token format") + } + + token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash()) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid token") + } + + // TODO: Enforce AccountID scope for "bring your own proxy" feature. + // Currently tokens are management-wide; AccountID field is reserved for future use. + + if !token.IsValid() { + return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") + } + + return token, nil +} + +// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update. +func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) { + now := time.Now() + + i.lastUsedMu.Lock() + lastUpdate, exists := i.lastUsedTimes[tokenID] + if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval { + i.lastUsedMu.Unlock() + return + } + i.lastUsedTimes[tokenID] = now + i.lastUsedMu.Unlock() + + if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil { + log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err) + } +} + +func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) { + ticker := time.NewTicker(lastUsedCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + i.cleanupStaleLastUsed() + case <-ctx.Done(): + return + } + } +} + +// cleanupStaleLastUsed removes entries older than 2x the update interval. +func (i *proxyAuthInterceptor) cleanupStaleLastUsed() { + i.lastUsedMu.Lock() + defer i.lastUsedMu.Unlock() + + now := time.Now() + staleThreshold := 2 * lastUsedUpdateInterval + for id, lastUpdate := range i.lastUsedTimes { + if now.Sub(lastUpdate) > staleThreshold { + delete(i.lastUsedTimes, id) + } + } +} + +func (i *proxyAuthInterceptor) close() { + i.cancel() + i.failureLimiter.stop() +} + +// GetProxyTokenFromContext retrieves the validated proxy token from the context +func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken { + token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken) + if !ok { + return nil + } + return token +} + +// wrappedServerStream wraps a grpc.ServerStream to provide a custom context +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} diff --git a/management/internals/shared/grpc/proxy_auth_ratelimit.go b/management/internals/shared/grpc/proxy_auth_ratelimit.go new file mode 100644 index 000000000..447e531b0 --- /dev/null +++ b/management/internals/shared/grpc/proxy_auth_ratelimit.go @@ -0,0 +1,134 @@ +package grpc + +import ( + "context" + "net" + "sync" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "golang.org/x/time/rate" + "google.golang.org/grpc/peer" +) + +const ( + // proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in. + proxyAuthFailureBurst = 5 + // proxyAuthLimiterCleanup is how often stale limiters are removed. + proxyAuthLimiterCleanup = 5 * time.Minute + // proxyAuthLimiterTTL is how long a limiter is kept after the last failure. + proxyAuthLimiterTTL = 15 * time.Minute +) + +// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts. +// One token every 12 seconds = 5 per minute. +var defaultProxyAuthFailureRate = rate.Every(12 * time.Second) + +// clientIP identifies a client by its IP address for rate limiting purposes. +type clientIP = string + +type limiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts. +type authFailureLimiter struct { + mu sync.Mutex + limiters map[clientIP]*limiterEntry + failureRate rate.Limit + cancel context.CancelFunc +} + +func newAuthFailureLimiter() *authFailureLimiter { + return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate) +} + +func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter { + ctx, cancel := context.WithCancel(context.Background()) + l := &authFailureLimiter{ + limiters: make(map[clientIP]*limiterEntry), + failureRate: failureRate, + cancel: cancel, + } + go l.cleanupLoop(ctx) + return l +} + +// isLimited returns true if the given IP has exhausted its failure budget. +func (l *authFailureLimiter) isLimited(ip clientIP) bool { + l.mu.Lock() + defer l.mu.Unlock() + + entry, exists := l.limiters[ip] + if !exists { + return false + } + + return entry.limiter.Tokens() < 1 +} + +// recordFailure consumes a token from the rate limiter for the given IP. +func (l *authFailureLimiter) recordFailure(ip clientIP) { + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + entry, exists := l.limiters[ip] + if !exists { + entry = &limiterEntry{ + limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst), + } + l.limiters[ip] = entry + } + entry.lastAccess = now + entry.limiter.Allow() +} + +func (l *authFailureLimiter) cleanupLoop(ctx context.Context) { + ticker := time.NewTicker(proxyAuthLimiterCleanup) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.cleanup() + case <-ctx.Done(): + return + } + } +} + +func (l *authFailureLimiter) cleanup() { + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + for ip, entry := range l.limiters { + if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL { + delete(l.limiters, ip) + } + } +} + +func (l *authFailureLimiter) stop() { + l.cancel() +} + +// peerIPFromContext extracts the client IP from the gRPC context. +// Uses realip (from trusted proxy headers) first, falls back to the transport peer address. +func peerIPFromContext(ctx context.Context) clientIP { + if addr, ok := realip.FromContext(ctx); ok { + return addr.String() + } + + if p, ok := peer.FromContext(ctx); ok { + host, _, err := net.SplitHostPort(p.Addr.String()) + if err != nil { + return p.Addr.String() + } + return host + } + + return "" +} diff --git a/management/internals/shared/grpc/proxy_auth_ratelimit_test.go b/management/internals/shared/grpc/proxy_auth_ratelimit_test.go new file mode 100644 index 000000000..3577baeb8 --- /dev/null +++ b/management/internals/shared/grpc/proxy_auth_ratelimit_test.go @@ -0,0 +1,98 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" +) + +func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) { + l := newAuthFailureLimiter() + defer l.stop() + + assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited") +} + +func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) { + l := newAuthFailureLimiter() + defer l.stop() + + ip := "192.168.1.1" + for i := 0; i < proxyAuthFailureBurst; i++ { + l.recordFailure(ip) + } + + assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst") +} + +func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) { + l := newAuthFailureLimiter() + defer l.stop() + + for i := 0; i < proxyAuthFailureBurst; i++ { + l.recordFailure("192.168.1.1") + } + + assert.True(t, l.isLimited("192.168.1.1")) + assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected") +} + +func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) { + l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery + defer l.stop() + + ip := "10.0.0.1" + + // Exhaust burst + for i := 0; i < proxyAuthFailureBurst; i++ { + l.recordFailure(ip) + } + require.True(t, l.isLimited(ip)) + + // Wait for token replenishment + time.Sleep(50 * time.Millisecond) + + assert.False(t, l.isLimited(ip), "should recover after tokens replenish") +} + +func TestAuthFailureLimiter_Cleanup(t *testing.T) { + l := newAuthFailureLimiter() + defer l.stop() + + l.recordFailure("10.0.0.1") + + l.mu.Lock() + require.Len(t, l.limiters, 1) + // Backdate the entry so it looks stale + l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute) + l.mu.Unlock() + + l.cleanup() + + l.mu.Lock() + assert.Empty(t, l.limiters, "stale entries should be cleaned up") + l.mu.Unlock() +} + +func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) { + l := newAuthFailureLimiter() + defer l.stop() + + l.recordFailure("10.0.0.1") + l.recordFailure("10.0.0.2") + + l.mu.Lock() + // Only backdate one entry + l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute) + l.mu.Unlock() + + l.cleanup() + + l.mu.Lock() + assert.Len(t, l.limiters, 1, "only stale entries should be removed") + assert.Contains(t, l.limiters, "10.0.0.2") + l.mu.Unlock() +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b82739a06..317274327 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -126,7 +126,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{}, + &types.Group{}, &types.GroupPeer{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, @@ -4309,6 +4310,79 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error { return nil } +// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value. +func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var token types.ProxyAccessToken + result := tx.Take(&token, "hashed_token = ?", hashedToken) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy access token not found") + } + return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error) + } + + return &token, nil +} + +// GetAllProxyAccessTokens retrieves all proxy access tokens. +func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var tokens []*types.ProxyAccessToken + result := tx.Find(&tokens) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error) + } + + return tokens, nil +} + +// SaveProxyAccessToken saves a proxy access token to the database. +func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error { + if result := s.db.WithContext(ctx).Create(token); result.Error != nil { + return status.Errorf(status.Internal, "save proxy access token: %v", result.Error) + } + return nil +} + +// RevokeProxyAccessToken revokes a proxy access token by its ID. +func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { + result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) + if result.Error != nil { + return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "proxy access token not found") + } + + return nil +} + +// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. +func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { + result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}). + Where(idQueryCondition, tokenID). + Update("last_used", time.Now().UTC()) + if result.Error != nil { + return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "proxy access token not found") + } + + return nil +} + func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/store.go b/management/server/store/store.go index feb403a38..d2fd059e0 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -109,6 +109,12 @@ type Store interface { SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error DeletePAT(ctx context.Context, userID, patID string) error + GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) + GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) + SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error + RevokeProxyAccessToken(ctx context.Context, tokenID string) error + MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) diff --git a/management/server/types/proxy_access_token.go b/management/server/types/proxy_access_token.go new file mode 100644 index 000000000..3a52eb735 --- /dev/null +++ b/management/server/types/proxy_access_token.go @@ -0,0 +1,137 @@ +package types + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "hash/crc32" + "strings" + "time" + + b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/rs/xid" + + "github.com/netbirdio/netbird/base62" + "github.com/netbirdio/netbird/management/server/util" +) + +const ( + // ProxyTokenPrefix is the globally used prefix for proxy access tokens + ProxyTokenPrefix = "nbx_" + // ProxyTokenSecretLength is the number of characters used for the secret + ProxyTokenSecretLength = 30 + // ProxyTokenChecksumLength is the number of characters used for the encoded checksum + ProxyTokenChecksumLength = 6 + // ProxyTokenLength is the total number of characters used for the token + ProxyTokenLength = 40 +) + +// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded. +type HashedProxyToken string + +// PlainProxyToken is the raw token string displayed once at creation time. +type PlainProxyToken string + +// ProxyAccessToken holds information about a proxy access token including a hashed version for verification +type ProxyAccessToken struct { + ID string `gorm:"primaryKey"` + Name string + HashedToken HashedProxyToken `gorm:"uniqueIndex"` + // AccountID is nil for management-wide tokens, set for account-scoped tokens + AccountID *string `gorm:"index"` + ExpiresAt *time.Time + CreatedBy string + CreatedAt time.Time + LastUsed *time.Time + Revoked bool +} + +// IsExpired returns true if the token has expired +func (t *ProxyAccessToken) IsExpired() bool { + if t.ExpiresAt == nil { + return false + } + return time.Now().After(*t.ExpiresAt) +} + +// IsValid returns true if the token is not revoked and not expired +func (t *ProxyAccessToken) IsValid() bool { + return !t.Revoked && !t.IsExpired() +} + +// ProxyAccessTokenGenerated holds the new token and the plain text version +type ProxyAccessTokenGenerated struct { + PlainToken PlainProxyToken + ProxyAccessToken +} + +// CreateNewProxyAccessToken generates a new proxy access token. +// Returns the token with hashed value stored and plain token for one-time display. +func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) { + hashedToken, plainToken, err := generateProxyToken() + if err != nil { + return nil, err + } + + currentTime := time.Now().UTC() + var expiresAt *time.Time + if expiresIn > 0 { + expiresAt = util.ToPtr(currentTime.Add(expiresIn)) + } + + return &ProxyAccessTokenGenerated{ + ProxyAccessToken: ProxyAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + AccountID: accountID, + ExpiresAt: expiresAt, + CreatedBy: createdBy, + CreatedAt: currentTime, + Revoked: false, + }, + PlainToken: plainToken, + }, nil +} + +func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) { + secret, err := b.Random(ProxyTokenSecretLength) + if err != nil { + return "", "", err + } + + checksum := crc32.ChecksumIEEE([]byte(secret)) + encodedChecksum := base62.Encode(checksum) + paddedChecksum := fmt.Sprintf("%06s", encodedChecksum) + plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum) + return plainToken.Hash(), plainToken, nil +} + +// Hash returns the SHA-256 hash of the plain token, base64-encoded. +func (t PlainProxyToken) Hash() HashedProxyToken { + h := sha256.Sum256([]byte(t)) + return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:])) +} + +// Validate checks the format of a proxy token without checking the database. +func (t PlainProxyToken) Validate() error { + if !strings.HasPrefix(string(t), ProxyTokenPrefix) { + return fmt.Errorf("invalid token prefix") + } + + if len(t) != ProxyTokenLength { + return fmt.Errorf("invalid token length") + } + + secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength] + checksumStr := t[len(t)-ProxyTokenChecksumLength:] + + expectedChecksum := crc32.ChecksumIEEE([]byte(secret)) + expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum)) + + if string(checksumStr) != expectedChecksumStr { + return fmt.Errorf("invalid token checksum") + } + + return nil +} diff --git a/management/server/types/proxy_access_token_test.go b/management/server/types/proxy_access_token_test.go new file mode 100644 index 000000000..aa1a4d2dd --- /dev/null +++ b/management/server/types/proxy_access_token_test.go @@ -0,0 +1,155 @@ +package types + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPlainProxyToken_Validate(t *testing.T) { + tests := []struct { + name string + token PlainProxyToken + wantErr bool + errMsg string + }{ + { + name: "valid token", + token: "", // will be generated + wantErr: false, + }, + { + name: "wrong prefix", + token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM", + wantErr: true, + errMsg: "invalid token prefix", + }, + { + name: "too short", + token: "nbx_short", + wantErr: true, + errMsg: "invalid token length", + }, + { + name: "too long", + token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra", + wantErr: true, + errMsg: "invalid token length", + }, + { + name: "correct length but invalid checksum", + token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum + wantErr: true, + errMsg: "invalid token checksum", + }, + { + name: "empty token", + token: "", + wantErr: true, + errMsg: "invalid token prefix", + }, + { + name: "only prefix", + token: "nbx_", + wantErr: true, + errMsg: "invalid token length", + }, + } + + // Generate a valid token for the first test + generated, err := CreateNewProxyAccessToken("test", 0, nil, "test") + require.NoError(t, err) + tests[0].token = generated.PlainToken + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.token.Validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPlainProxyToken_Hash(t *testing.T) { + token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM") + token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM") + token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X") + + hash1 := token1.Hash() + hash2 := token2.Hash() + hash3 := token3.Hash() + + assert.Equal(t, hash1, hash2, "same token should produce same hash") + assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes") + assert.NotEmpty(t, hash1) +} + +func TestCreateNewProxyAccessToken(t *testing.T) { + t.Run("creates valid token", func(t *testing.T) { + generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user") + require.NoError(t, err) + + assert.NotEmpty(t, generated.ID) + assert.Equal(t, "test-token", generated.Name) + assert.Equal(t, "test-user", generated.CreatedBy) + assert.NotEmpty(t, generated.HashedToken) + assert.NotEmpty(t, generated.PlainToken) + assert.Nil(t, generated.ExpiresAt) + assert.False(t, generated.Revoked) + + assert.NoError(t, generated.PlainToken.Validate()) + assert.Equal(t, ProxyTokenLength, len(generated.PlainToken)) + assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)])) + }) + + t.Run("tokens are unique", func(t *testing.T) { + gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user") + require.NoError(t, err) + + gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user") + require.NoError(t, err) + + assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken) + assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken) + assert.NotEqual(t, gen1.ID, gen2.ID) + }) +} + +func TestProxyAccessToken_IsExpired(t *testing.T) { + past := time.Now().Add(-1 * time.Hour) + future := time.Now().Add(1 * time.Hour) + + t.Run("expired token", func(t *testing.T) { + token := &ProxyAccessToken{ExpiresAt: &past} + assert.True(t, token.IsExpired()) + }) + + t.Run("not expired token", func(t *testing.T) { + token := &ProxyAccessToken{ExpiresAt: &future} + assert.False(t, token.IsExpired()) + }) + + t.Run("no expiration", func(t *testing.T) { + token := &ProxyAccessToken{ExpiresAt: nil} + assert.False(t, token.IsExpired()) + }) +} + +func TestProxyAccessToken_IsValid(t *testing.T) { + token := &ProxyAccessToken{ + Revoked: false, + } + + assert.True(t, token.IsValid()) + + token.Revoked = true + assert.False(t, token.IsValid()) +} diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 0a8cd6de5..e128440da 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "fmt" "os" "strconv" "strings" @@ -16,6 +17,9 @@ import ( const DefaultManagementURL = "https://api.netbird.io:443" +// envProxyToken is the environment variable name for the proxy access token. +const envProxyToken = "NB_PROXY_TOKEN" + var ( Version = "dev" Commit = "unknown" @@ -42,11 +46,12 @@ var ( ) var rootCmd = &cobra.Command{ - Use: "proxy", - Short: "NetBird reverse proxy server", - Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.", - Version: Version, - RunE: runServer, + Use: "proxy", + Short: "NetBird reverse proxy server", + Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.", + Version: Version, + SilenceUsage: true, + RunE: runServer, } func init() { @@ -85,6 +90,11 @@ func SetVersionInfo(version, commit, buildDate, goVersion string) { } func runServer(cmd *cobra.Command, args []string) error { + proxyToken := os.Getenv(envProxyToken) + if proxyToken == "" { + return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken) + } + level := "error" if debugLogs { level = "debug" @@ -100,6 +110,7 @@ func runServer(cmd *cobra.Command, args []string) error { Version: Version, ManagementAddress: mgmtAddr, ProxyURL: proxyURL, + ProxyToken: proxyToken, CertificateDirectory: certDir, GenerateACMECertificates: acmeCerts, ACMEChallengeAddress: acmeAddr, diff --git a/proxy/deploy/k8s/deployment.yaml b/proxy/deploy/k8s/deployment.yaml index 94b1e4e9e..0611c541b 100644 --- a/proxy/deploy/k8s/deployment.yaml +++ b/proxy/deploy/k8s/deployment.yaml @@ -49,6 +49,13 @@ spec: value: "https://proxy.local" - name: NB_PROXY_CERTIFICATE_DIRECTORY value: "/certs" + - name: NB_PROXY_TOKEN + valueFrom: + secretKeyRef: + name: netbird-proxy-token + key: token + - name: NB_PROXY_ALLOW_INSECURE + value: "true" # Required for HTTP management connection in dev volumeMounts: - name: tls-certs mountPath: /certs diff --git a/proxy/internal/grpc/auth.go b/proxy/internal/grpc/auth.go new file mode 100644 index 000000000..ce1a23f68 --- /dev/null +++ b/proxy/internal/grpc/auth.go @@ -0,0 +1,48 @@ +// Package grpc provides gRPC utilities for the proxy client. +package grpc + +import ( + "context" + "os" + "strconv" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// EnvProxyAllowInsecure controls whether the proxy token can be sent over non-TLS connections. +const EnvProxyAllowInsecure = "NB_PROXY_ALLOW_INSECURE" + +var _ credentials.PerRPCCredentials = (*proxyAuthToken)(nil) + +type proxyAuthToken struct { + token string + allowInsecure bool +} + +func (t proxyAuthToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return map[string]string{ + "authorization": "Bearer " + t.token, + }, nil +} + +// RequireTransportSecurity returns true by default to protect the token in transit. +// Set NB_PROXY_ALLOW_INSECURE=true to allow non-TLS connections (not recommended for production). +func (t proxyAuthToken) RequireTransportSecurity() bool { + return !t.allowInsecure +} + +// WithProxyToken returns a DialOption that sets the proxy access token on each outbound RPC. +func WithProxyToken(token string) grpc.DialOption { + allowInsecure := false + if val := os.Getenv(EnvProxyAllowInsecure); val != "" { + parsed, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("invalid value for %s: %v", EnvProxyAllowInsecure, err) + } else { + allowInsecure = parsed + } + } + return grpc.WithPerRPCCredentials(proxyAuthToken{token: token, allowInsecure: allowInsecure}) +} diff --git a/proxy/server.go b/proxy/server.go index cda6211d2..aed15912e 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -32,6 +32,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/debug" + proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" @@ -76,6 +77,8 @@ type Server struct { DebugEndpointAddress string // HealthAddress is the address for the health probe endpoint (default: "localhost:8080"). HealthAddress string + // ProxyToken is the access token for authenticating with the management server. + ProxyToken string } // NotifyStatus sends a status update to management about tunnel connectivity @@ -153,6 +156,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { Timeout: 10 * time.Second, PermitWithoutStream: true, }), + proxygrpc.WithProxyToken(s.ProxyToken), ) if err != nil { return fmt.Errorf("could not create management connection: %w", err)