mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
add rate limit
This commit is contained in:
@@ -3,8 +3,11 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,6 +16,7 @@ import (
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
@@ -47,6 +51,10 @@ type GRPCServer struct {
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
syncLimiter *rate.Limiter
|
||||
loginLimiterStore sync.Map
|
||||
loginPeerBooster int
|
||||
loginPeerLimit rate.Limit
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -76,6 +84,41 @@ func NewServer(
|
||||
}
|
||||
}
|
||||
|
||||
multiplier := time.Minute
|
||||
d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE"))
|
||||
if e == nil {
|
||||
multiplier = d
|
||||
}
|
||||
|
||||
loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M"))
|
||||
if loginRatePerS == 0 || err != nil {
|
||||
loginRatePerS = 200
|
||||
}
|
||||
|
||||
loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST"))
|
||||
if loginBurst == 0 || err != nil {
|
||||
loginBurst = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst)
|
||||
|
||||
loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M"))
|
||||
if loginPeerRatePerS == 0 || err != nil {
|
||||
loginPeerRatePerS = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS)
|
||||
|
||||
syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M"))
|
||||
if syncRatePerS == 0 || err != nil {
|
||||
syncRatePerS = 20000
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS)
|
||||
|
||||
syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST"))
|
||||
if syncBurst == 0 || err != nil {
|
||||
syncBurst = 30000
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst)
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -87,6 +130,9 @@ func NewServer(
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst),
|
||||
loginPeerLimit: rate.Every(multiplier / time.Duration(loginPeerRatePerS)),
|
||||
loginPeerBooster: loginBurst,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -128,11 +174,18 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
reqStart := time.Now()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
if !s.syncLimiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return status.Errorf(codes.Internal, "temp rate limit reached")
|
||||
}
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
@@ -428,15 +481,39 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey)
|
||||
if !ok {
|
||||
// Create new limiter for this peer
|
||||
newLimiter := rate.NewLimiter(s.loginPeerLimit, s.loginPeerBooster)
|
||||
s.loginLimiterStore.Store(req.WgPubKey, newLimiter)
|
||||
|
||||
if !newLimiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (new peer limit)")
|
||||
}
|
||||
} else {
|
||||
// Use existing limiter for this peer
|
||||
limiter := limiterIface.(*rate.Limiter)
|
||||
if !limiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (peer limit)")
|
||||
}
|
||||
}
|
||||
reqStart := time.Now()
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
//if s.appMetrics != nil {
|
||||
// s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
//}
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user