diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2b27f9e0f..86d4e7fab 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,8 +3,11 @@ package server import ( "context" "fmt" + "math/rand/v2" "net" "net/netip" + "os" + "strconv" "strings" "sync" "time" @@ -13,6 +16,7 @@ import ( "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" @@ -47,6 +51,10 @@ type GRPCServer struct { ephemeralManager *EphemeralManager peerLocks sync.Map authManager auth.Manager + syncLimiter *rate.Limiter + loginLimiterStore sync.Map + loginPeerBooster int + loginPeerLimit rate.Limit } // NewServer creates a new Management server @@ -76,6 +84,41 @@ func NewServer( } } + multiplier := time.Minute + d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE")) + if e == nil { + multiplier = d + } + + loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M")) + if loginRatePerS == 0 || err != nil { + loginRatePerS = 200 + } + + loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST")) + if loginBurst == 0 || err != nil { + loginBurst = 200 + } + log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst) + + loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M")) + if loginPeerRatePerS == 0 || err != nil { + loginPeerRatePerS = 200 + } + log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS) + + syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M")) + if syncRatePerS == 0 || err != nil { + syncRatePerS = 20000 + } + log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS) + + syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST")) + if syncBurst == 0 || err != nil { + syncBurst = 30000 + } + log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst) + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -87,6 +130,9 @@ func NewServer( authManager: authManager, appMetrics: appMetrics, ephemeralManager: ephemeralManager, + syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst), + loginPeerLimit: rate.Every(multiplier / time.Duration(loginPeerRatePerS)), + loginPeerBooster: loginBurst, }, nil } @@ -128,11 +174,18 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - reqStart := time.Now() if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequest() } + if !s.syncLimiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey) + return status.Errorf(codes.Internal, "temp rate limit reached") + } + + reqStart := time.Now() + ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -428,15 +481,39 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + + limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey) + if !ok { + // Create new limiter for this peer + newLimiter := rate.NewLimiter(s.loginPeerLimit, s.loginPeerBooster) + s.loginLimiterStore.Store(req.WgPubKey, newLimiter) + + if !newLimiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) + return nil, fmt.Errorf("temp rate limit reached (new peer limit)") + } + } else { + // Use existing limiter for this peer + limiter := limiterIface.(*rate.Limiter) + if !limiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) + return nil, fmt.Errorf("temp rate limit reached (peer limit)") + } + } reqStart := time.Now() defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) } }() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } + //if s.appMetrics != nil { + // s.appMetrics.GRPCMetrics().CountLoginRequest() + //} realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())