mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management, proxy] Add CrowdSec IP reputation integration for reverse proxy (#5722)
This commit is contained in:
@@ -35,7 +35,7 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
logLevel string
|
||||
logLevel string
|
||||
debugLogs bool
|
||||
mgmtAddr string
|
||||
addr string
|
||||
@@ -64,6 +64,8 @@ var (
|
||||
supportsCustomPorts bool
|
||||
requireSubdomain bool
|
||||
geoDataDir string
|
||||
crowdsecAPIURL string
|
||||
crowdsecAPIKey string
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -106,6 +108,8 @@ func init() {
|
||||
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
|
||||
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
|
||||
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
|
||||
rootCmd.Flags().StringVar(&crowdsecAPIURL, "crowdsec-api-url", envStringOrDefault("NB_PROXY_CROWDSEC_API_URL", ""), "CrowdSec LAPI URL for IP reputation checks")
|
||||
rootCmd.Flags().StringVar(&crowdsecAPIKey, "crowdsec-api-key", envStringOrDefault("NB_PROXY_CROWDSEC_API_KEY", ""), "CrowdSec bouncer API key")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -187,6 +191,8 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
@@ -2,6 +2,7 @@ package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -126,6 +127,7 @@ type logEntry struct {
|
||||
BytesUpload int64
|
||||
BytesDownload int64
|
||||
Protocol Protocol
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// Protocol identifies the transport protocol of an access log entry.
|
||||
@@ -150,8 +152,10 @@ type L4Entry struct {
|
||||
BytesDownload int64
|
||||
// DenyReason, when non-empty, indicates the connection was denied.
|
||||
// Values match the HTTP auth mechanism strings: "ip_restricted",
|
||||
// "country_restricted", "geo_unavailable".
|
||||
// "country_restricted", "geo_unavailable", "crowdsec_ban", etc.
|
||||
DenyReason string
|
||||
// Metadata carries extra context about the connection (e.g. CrowdSec verdict).
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// LogL4 sends an access log entry for a layer-4 connection (TCP or UDP).
|
||||
@@ -167,6 +171,7 @@ func (l *Logger) LogL4(entry L4Entry) {
|
||||
DurationMs: entry.DurationMs,
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
Metadata: maps.Clone(entry.Metadata),
|
||||
}
|
||||
if entry.DenyReason != "" {
|
||||
if !l.allowDenyLog(entry.ServiceID, entry.DenyReason) {
|
||||
@@ -258,6 +263,7 @@ func (l *Logger) log(entry logEntry) {
|
||||
BytesUpload: entry.BytesUpload,
|
||||
BytesDownload: entry.BytesDownload,
|
||||
Protocol: string(entry.Protocol),
|
||||
Metadata: entry.Metadata,
|
||||
},
|
||||
}); err != nil {
|
||||
l.logger.WithFields(log.Fields{
|
||||
|
||||
@@ -82,6 +82,7 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
BytesUpload: bytesUpload,
|
||||
BytesDownload: bytesDownload,
|
||||
Protocol: ProtocolHTTP,
|
||||
Metadata: capturedData.GetMetadata(),
|
||||
}
|
||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceID(), capturedData.GetAccountID())
|
||||
|
||||
@@ -167,6 +167,20 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request
|
||||
return true
|
||||
}
|
||||
|
||||
if verdict.IsCrowdSec() {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetMetadata("crowdsec_verdict", verdict.String())
|
||||
if config.IPRestrictions.IsObserveOnly(verdict) {
|
||||
cd.SetMetadata("crowdsec_mode", "observe")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.IPRestrictions.IsObserveOnly(verdict) {
|
||||
mw.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", clientIP, r.Host, verdict)
|
||||
return true
|
||||
}
|
||||
|
||||
reason := verdict.String()
|
||||
mw.blockIPRestriction(r, reason)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
|
||||
@@ -669,7 +669,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -705,7 +705,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter([]string{"203.0.113.0/24"}, nil, nil, nil))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -746,7 +746,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(nil, nil, []string{"US"}, nil))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
251
proxy/internal/crowdsec/bouncer.go
Normal file
251
proxy/internal/crowdsec/bouncer.go
Normal file
@@ -0,0 +1,251 @@
|
||||
// Package crowdsec provides a CrowdSec stream bouncer that maintains a local
|
||||
// decision cache for IP reputation checks.
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||
csbouncer "github.com/crowdsecurity/go-cs-bouncer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
)
|
||||
|
||||
// Bouncer wraps a CrowdSec StreamBouncer, maintaining a local cache of
|
||||
// active decisions for fast IP lookups. It implements restrict.CrowdSecChecker.
|
||||
type Bouncer struct {
|
||||
mu sync.RWMutex
|
||||
ips map[netip.Addr]*restrict.CrowdSecDecision
|
||||
prefixes map[netip.Prefix]*restrict.CrowdSecDecision
|
||||
ready atomic.Bool
|
||||
|
||||
apiURL string
|
||||
apiKey string
|
||||
tickerInterval time.Duration
|
||||
logger *log.Entry
|
||||
|
||||
// lifeMu protects cancel and done from concurrent Start/Stop calls.
|
||||
lifeMu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// compile-time check
|
||||
var _ restrict.CrowdSecChecker = (*Bouncer)(nil)
|
||||
|
||||
// NewBouncer creates a bouncer but does not start the stream.
|
||||
func NewBouncer(apiURL, apiKey string, logger *log.Entry) *Bouncer {
|
||||
return &Bouncer{
|
||||
apiURL: apiURL,
|
||||
apiKey: apiKey,
|
||||
logger: logger,
|
||||
ips: make(map[netip.Addr]*restrict.CrowdSecDecision),
|
||||
prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the background goroutine that streams decisions from the
|
||||
// CrowdSec LAPI. The stream runs until Stop is called or ctx is cancelled.
|
||||
func (b *Bouncer) Start(ctx context.Context) error {
|
||||
interval := b.tickerInterval
|
||||
if interval == 0 {
|
||||
interval = 10 * time.Second
|
||||
}
|
||||
stream := &csbouncer.StreamBouncer{
|
||||
APIKey: b.apiKey,
|
||||
APIUrl: b.apiURL,
|
||||
TickerInterval: interval.String(),
|
||||
UserAgent: "netbird-proxy/1.0",
|
||||
Scopes: []string{"ip", "range"},
|
||||
RetryInitialConnect: true,
|
||||
}
|
||||
|
||||
b.logger.Infof("connecting to CrowdSec LAPI at %s", b.apiURL)
|
||||
|
||||
if err := stream.Init(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reset state from any previous run.
|
||||
b.mu.Lock()
|
||||
b.ips = make(map[netip.Addr]*restrict.CrowdSecDecision)
|
||||
b.prefixes = make(map[netip.Prefix]*restrict.CrowdSecDecision)
|
||||
b.mu.Unlock()
|
||||
b.ready.Store(false)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
done := make(chan struct{})
|
||||
|
||||
b.lifeMu.Lock()
|
||||
if b.cancel != nil {
|
||||
b.lifeMu.Unlock()
|
||||
cancel()
|
||||
return errors.New("bouncer already started")
|
||||
}
|
||||
b.cancel = cancel
|
||||
b.done = done
|
||||
b.lifeMu.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := stream.Run(ctx); err != nil && ctx.Err() == nil {
|
||||
b.logger.Errorf("CrowdSec stream ended: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.consumeStream(ctx, stream)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop cancels the stream and waits for all goroutines to finish.
|
||||
func (b *Bouncer) Stop() {
|
||||
b.lifeMu.Lock()
|
||||
cancel := b.cancel
|
||||
done := b.done
|
||||
b.cancel = nil
|
||||
b.lifeMu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Ready returns true after the first batch of decisions has been processed.
|
||||
func (b *Bouncer) Ready() bool {
|
||||
return b.ready.Load()
|
||||
}
|
||||
|
||||
// CheckIP looks up addr in the local decision cache. Returns nil if no
|
||||
// active decision exists for the address.
|
||||
//
|
||||
// Prefix lookups are O(1): instead of scanning all stored prefixes, we
|
||||
// probe the map for every possible containing prefix of the address
|
||||
// (at most 33 for IPv4, 129 for IPv6).
|
||||
func (b *Bouncer) CheckIP(addr netip.Addr) *restrict.CrowdSecDecision {
|
||||
addr = addr.Unmap()
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if d, ok := b.ips[addr]; ok {
|
||||
return d
|
||||
}
|
||||
|
||||
maxBits := 32
|
||||
if addr.Is6() {
|
||||
maxBits = 128
|
||||
}
|
||||
// Walk from most-specific to least-specific prefix so the narrowest
|
||||
// matching decision wins when ranges overlap.
|
||||
for bits := maxBits; bits >= 0; bits-- {
|
||||
prefix := netip.PrefixFrom(addr, bits).Masked()
|
||||
if d, ok := b.prefixes[prefix]; ok {
|
||||
return d
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bouncer) consumeStream(ctx context.Context, stream *csbouncer.StreamBouncer) {
|
||||
first := true
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case resp, ok := <-stream.Stream:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
b.applyDeleted(resp.Deleted)
|
||||
b.applyNew(resp.New)
|
||||
b.mu.Unlock()
|
||||
|
||||
if first {
|
||||
b.ready.Store(true)
|
||||
b.logger.Info("CrowdSec bouncer synced initial decisions")
|
||||
first = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bouncer) applyDeleted(decisions []*models.Decision) {
|
||||
for _, d := range decisions {
|
||||
if d.Value == nil || d.Scope == nil {
|
||||
continue
|
||||
}
|
||||
value := *d.Value
|
||||
|
||||
if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") {
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
b.logger.Debugf("skip unparsable CrowdSec range deletion %q: %v", value, err)
|
||||
continue
|
||||
}
|
||||
prefix = normalizePrefix(prefix)
|
||||
delete(b.prefixes, prefix)
|
||||
} else {
|
||||
addr, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
b.logger.Debugf("skip unparsable CrowdSec IP deletion %q: %v", value, err)
|
||||
continue
|
||||
}
|
||||
delete(b.ips, addr.Unmap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bouncer) applyNew(decisions []*models.Decision) {
|
||||
for _, d := range decisions {
|
||||
if d.Value == nil || d.Type == nil || d.Scope == nil {
|
||||
continue
|
||||
}
|
||||
dec := &restrict.CrowdSecDecision{Type: restrict.DecisionType(*d.Type)}
|
||||
value := *d.Value
|
||||
|
||||
if strings.ToLower(*d.Scope) == "range" || strings.Contains(value, "/") {
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
b.logger.Debugf("skip unparsable CrowdSec range %q: %v", value, err)
|
||||
continue
|
||||
}
|
||||
prefix = normalizePrefix(prefix)
|
||||
b.prefixes[prefix] = dec
|
||||
} else {
|
||||
addr, err := netip.ParseAddr(value)
|
||||
if err != nil {
|
||||
b.logger.Debugf("skip unparsable CrowdSec IP %q: %v", value, err)
|
||||
continue
|
||||
}
|
||||
b.ips[addr.Unmap()] = dec
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizePrefix unmaps v4-mapped-v6 addresses and zeros host bits so
|
||||
// the prefix is a valid map key that matches CheckIP's probe logic.
|
||||
func normalizePrefix(p netip.Prefix) netip.Prefix {
|
||||
return netip.PrefixFrom(p.Addr().Unmap(), p.Bits()).Masked()
|
||||
}
|
||||
337
proxy/internal/crowdsec/bouncer_test.go
Normal file
337
proxy/internal/crowdsec/bouncer_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
)
|
||||
|
||||
func TestBouncer_CheckIP_Empty(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ready.Store(true)
|
||||
|
||||
assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4")))
|
||||
}
|
||||
|
||||
func TestBouncer_CheckIP_ExactMatch(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ready.Store(true)
|
||||
b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
d := b.CheckIP(netip.MustParseAddr("10.0.0.1"))
|
||||
require.NotNil(t, d)
|
||||
assert.Equal(t, restrict.DecisionBan, d.Type)
|
||||
|
||||
assert.Nil(t, b.CheckIP(netip.MustParseAddr("10.0.0.2")))
|
||||
}
|
||||
|
||||
func TestBouncer_CheckIP_PrefixMatch(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ready.Store(true)
|
||||
b.prefixes[netip.MustParsePrefix("192.168.1.0/24")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
d := b.CheckIP(netip.MustParseAddr("192.168.1.100"))
|
||||
require.NotNil(t, d)
|
||||
assert.Equal(t, restrict.DecisionBan, d.Type)
|
||||
|
||||
assert.Nil(t, b.CheckIP(netip.MustParseAddr("192.168.2.1")))
|
||||
}
|
||||
|
||||
func TestBouncer_CheckIP_UnmapsV4InV6(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ready.Store(true)
|
||||
b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
d := b.CheckIP(netip.MustParseAddr("::ffff:10.0.0.1"))
|
||||
require.NotNil(t, d)
|
||||
assert.Equal(t, restrict.DecisionBan, d.Type)
|
||||
}
|
||||
|
||||
func TestBouncer_Ready(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
assert.False(t, b.Ready())
|
||||
|
||||
b.ready.Store(true)
|
||||
assert.True(t, b.Ready())
|
||||
}
|
||||
|
||||
func TestBouncer_CheckIP_ExactBeforePrefix(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ready.Store(true)
|
||||
b.ips[netip.MustParseAddr("10.0.0.1")] = &restrict.CrowdSecDecision{Type: restrict.DecisionCaptcha}
|
||||
b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
d := b.CheckIP(netip.MustParseAddr("10.0.0.1"))
|
||||
require.NotNil(t, d)
|
||||
assert.Equal(t, restrict.DecisionCaptcha, d.Type)
|
||||
|
||||
d2 := b.CheckIP(netip.MustParseAddr("10.0.0.2"))
|
||||
require.NotNil(t, d2)
|
||||
assert.Equal(t, restrict.DecisionBan, d2.Type)
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyNew_IP(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
|
||||
b.applyNew(makeDecisions(
|
||||
decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "test/brute"},
|
||||
decision{scope: "ip", value: "5.6.7.8", dtype: "captcha", scenario: "test/crawl"},
|
||||
))
|
||||
|
||||
require.Len(t, b.ips, 2)
|
||||
assert.Equal(t, restrict.DecisionBan, b.ips[netip.MustParseAddr("1.2.3.4")].Type)
|
||||
assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("5.6.7.8")].Type)
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyNew_Range(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
|
||||
b.applyNew(makeDecisions(
|
||||
decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"},
|
||||
))
|
||||
|
||||
require.Len(t, b.prefixes, 1)
|
||||
assert.NotNil(t, b.prefixes[netip.MustParsePrefix("10.0.0.0/8")])
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyDeleted_IP(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
b.ips[netip.MustParseAddr("5.6.7.8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
b.applyDeleted(makeDecisions(
|
||||
decision{scope: "ip", value: "1.2.3.4", dtype: "ban"},
|
||||
))
|
||||
|
||||
assert.Len(t, b.ips, 1)
|
||||
assert.Nil(t, b.ips[netip.MustParseAddr("1.2.3.4")])
|
||||
assert.NotNil(t, b.ips[netip.MustParseAddr("5.6.7.8")])
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyDeleted_Range(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.prefixes[netip.MustParsePrefix("10.0.0.0/8")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
b.prefixes[netip.MustParsePrefix("192.168.0.0/16")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
b.applyDeleted(makeDecisions(
|
||||
decision{scope: "range", value: "10.0.0.0/8", dtype: "ban"},
|
||||
))
|
||||
|
||||
require.Len(t, b.prefixes, 1)
|
||||
assert.NotNil(t, b.prefixes[netip.MustParsePrefix("192.168.0.0/16")])
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyNew_OverwritesExisting(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
b.ips[netip.MustParseAddr("1.2.3.4")] = &restrict.CrowdSecDecision{Type: restrict.DecisionBan}
|
||||
|
||||
b.applyNew(makeDecisions(
|
||||
decision{scope: "ip", value: "1.2.3.4", dtype: "captcha"},
|
||||
))
|
||||
|
||||
assert.Equal(t, restrict.DecisionCaptcha, b.ips[netip.MustParseAddr("1.2.3.4")].Type)
|
||||
}
|
||||
|
||||
func TestBouncer_ApplyNew_SkipsInvalid(t *testing.T) {
|
||||
b := newTestBouncer()
|
||||
|
||||
b.applyNew(makeDecisions(
|
||||
decision{scope: "ip", value: "not-an-ip", dtype: "ban"},
|
||||
decision{scope: "range", value: "also-not-valid", dtype: "ban"},
|
||||
))
|
||||
|
||||
assert.Empty(t, b.ips)
|
||||
assert.Empty(t, b.prefixes)
|
||||
}
|
||||
|
||||
// TestBouncer_StreamIntegration tests the full flow: fake LAPI → StreamBouncer → Bouncer cache → CheckIP.
|
||||
func TestBouncer_StreamIntegration(t *testing.T) {
|
||||
lapi := newFakeLAPI()
|
||||
ts := httptest.NewServer(lapi)
|
||||
defer ts.Close()
|
||||
|
||||
// Seed the LAPI with initial decisions.
|
||||
lapi.setDecisions(
|
||||
decision{scope: "ip", value: "1.2.3.4", dtype: "ban", scenario: "crowdsecurity/ssh-bf"},
|
||||
decision{scope: "range", value: "10.0.0.0/8", dtype: "ban", scenario: "crowdsecurity/http-probing"},
|
||||
decision{scope: "ip", value: "5.5.5.5", dtype: "captcha", scenario: "crowdsecurity/http-crawl"},
|
||||
)
|
||||
|
||||
b := NewBouncer(ts.URL, "test-key", log.NewEntry(log.StandardLogger()))
|
||||
b.tickerInterval = 200 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, b.Start(ctx))
|
||||
defer b.Stop()
|
||||
|
||||
// Wait for initial sync.
|
||||
require.Eventually(t, b.Ready, 5*time.Second, 50*time.Millisecond, "bouncer should become ready")
|
||||
|
||||
// Verify decisions are cached.
|
||||
d := b.CheckIP(netip.MustParseAddr("1.2.3.4"))
|
||||
require.NotNil(t, d, "1.2.3.4 should be banned")
|
||||
assert.Equal(t, restrict.DecisionBan, d.Type)
|
||||
|
||||
d2 := b.CheckIP(netip.MustParseAddr("10.1.2.3"))
|
||||
require.NotNil(t, d2, "10.1.2.3 should match range ban")
|
||||
assert.Equal(t, restrict.DecisionBan, d2.Type)
|
||||
|
||||
d3 := b.CheckIP(netip.MustParseAddr("5.5.5.5"))
|
||||
require.NotNil(t, d3, "5.5.5.5 should have captcha")
|
||||
assert.Equal(t, restrict.DecisionCaptcha, d3.Type)
|
||||
|
||||
assert.Nil(t, b.CheckIP(netip.MustParseAddr("9.9.9.9")), "unknown IP should be nil")
|
||||
|
||||
// Simulate a delta update: delete one IP, add a new one.
|
||||
lapi.setDelta(
|
||||
[]decision{{scope: "ip", value: "1.2.3.4", dtype: "ban"}},
|
||||
[]decision{{scope: "ip", value: "2.3.4.5", dtype: "throttle", scenario: "crowdsecurity/http-flood"}},
|
||||
)
|
||||
|
||||
// Wait for the delta to be picked up.
|
||||
require.Eventually(t, func() bool {
|
||||
return b.CheckIP(netip.MustParseAddr("2.3.4.5")) != nil
|
||||
}, 5*time.Second, 50*time.Millisecond, "new decision should appear")
|
||||
|
||||
assert.Nil(t, b.CheckIP(netip.MustParseAddr("1.2.3.4")), "deleted decision should be gone")
|
||||
|
||||
d4 := b.CheckIP(netip.MustParseAddr("2.3.4.5"))
|
||||
require.NotNil(t, d4)
|
||||
assert.Equal(t, restrict.DecisionThrottle, d4.Type)
|
||||
|
||||
// Range ban should still be active.
|
||||
assert.NotNil(t, b.CheckIP(netip.MustParseAddr("10.99.99.99")))
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func newTestBouncer() *Bouncer {
|
||||
return &Bouncer{
|
||||
ips: make(map[netip.Addr]*restrict.CrowdSecDecision),
|
||||
prefixes: make(map[netip.Prefix]*restrict.CrowdSecDecision),
|
||||
logger: log.NewEntry(log.StandardLogger()),
|
||||
}
|
||||
}
|
||||
|
||||
type decision struct {
|
||||
scope string
|
||||
value string
|
||||
dtype string
|
||||
scenario string
|
||||
}
|
||||
|
||||
func makeDecisions(decs ...decision) []*models.Decision {
|
||||
out := make([]*models.Decision, len(decs))
|
||||
for i, d := range decs {
|
||||
out[i] = &models.Decision{
|
||||
Scope: strPtr(d.scope),
|
||||
Value: strPtr(d.value),
|
||||
Type: strPtr(d.dtype),
|
||||
Scenario: strPtr(d.scenario),
|
||||
Duration: strPtr("1h"),
|
||||
Origin: strPtr("cscli"),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func strPtr(s string) *string { return &s }
|
||||
|
||||
// fakeLAPI is a minimal fake CrowdSec LAPI that serves /v1/decisions/stream.
|
||||
type fakeLAPI struct {
|
||||
mu sync.Mutex
|
||||
initial []decision
|
||||
newDelta []decision
|
||||
delDelta []decision
|
||||
served bool // true after the initial snapshot has been served
|
||||
}
|
||||
|
||||
func newFakeLAPI() *fakeLAPI {
|
||||
return &fakeLAPI{}
|
||||
}
|
||||
|
||||
func (f *fakeLAPI) setDecisions(decs ...decision) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.initial = decs
|
||||
f.served = false
|
||||
}
|
||||
|
||||
func (f *fakeLAPI) setDelta(deleted, added []decision) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.delDelta = deleted
|
||||
f.newDelta = added
|
||||
}
|
||||
|
||||
func (f *fakeLAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/decisions/stream" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
resp := streamResponse{}
|
||||
|
||||
if !f.served {
|
||||
for _, d := range f.initial {
|
||||
resp.New = append(resp.New, toLAPIDecision(d))
|
||||
}
|
||||
f.served = true
|
||||
} else {
|
||||
for _, d := range f.delDelta {
|
||||
resp.Deleted = append(resp.Deleted, toLAPIDecision(d))
|
||||
}
|
||||
for _, d := range f.newDelta {
|
||||
resp.New = append(resp.New, toLAPIDecision(d))
|
||||
}
|
||||
// Clear delta after serving once.
|
||||
f.delDelta = nil
|
||||
f.newDelta = nil
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
// streamResponse mirrors the CrowdSec LAPI /v1/decisions/stream JSON structure.
|
||||
type streamResponse struct {
|
||||
New []*lapiDecision `json:"new"`
|
||||
Deleted []*lapiDecision `json:"deleted"`
|
||||
}
|
||||
|
||||
type lapiDecision struct {
|
||||
Duration *string `json:"duration"`
|
||||
Origin *string `json:"origin"`
|
||||
Scenario *string `json:"scenario"`
|
||||
Scope *string `json:"scope"`
|
||||
Type *string `json:"type"`
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
|
||||
func toLAPIDecision(d decision) *lapiDecision {
|
||||
return &lapiDecision{
|
||||
Duration: strPtr("1h"),
|
||||
Origin: strPtr("cscli"),
|
||||
Scenario: strPtr(d.scenario),
|
||||
Scope: strPtr(d.scope),
|
||||
Type: strPtr(d.dtype),
|
||||
Value: strPtr(d.value),
|
||||
}
|
||||
}
|
||||
103
proxy/internal/crowdsec/registry.go
Normal file
103
proxy/internal/crowdsec/registry.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// Registry manages a single shared Bouncer instance with reference counting.
|
||||
// The bouncer starts when the first service acquires it and stops when the
|
||||
// last service releases it.
|
||||
type Registry struct {
|
||||
mu sync.Mutex
|
||||
bouncer *Bouncer
|
||||
refs map[types.ServiceID]struct{}
|
||||
apiURL string
|
||||
apiKey string
|
||||
logger *log.Entry
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRegistry creates a registry. The bouncer is not started until Acquire is called.
|
||||
func NewRegistry(apiURL, apiKey string, logger *log.Entry) *Registry {
|
||||
return &Registry{
|
||||
apiURL: apiURL,
|
||||
apiKey: apiKey,
|
||||
logger: logger,
|
||||
refs: make(map[types.ServiceID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Available returns true when the LAPI URL and API key are configured.
|
||||
func (r *Registry) Available() bool {
|
||||
return r.apiURL != "" && r.apiKey != ""
|
||||
}
|
||||
|
||||
// Acquire registers svcID as a consumer and starts the bouncer if this is the
|
||||
// first consumer. Returns the shared Bouncer (which implements the restrict
|
||||
// package's CrowdSecChecker interface). Returns nil if not Available.
|
||||
func (r *Registry) Acquire(svcID types.ServiceID) *Bouncer {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if !r.Available() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := r.refs[svcID]; exists {
|
||||
return r.bouncer
|
||||
}
|
||||
|
||||
if r.bouncer == nil {
|
||||
r.startLocked()
|
||||
}
|
||||
|
||||
// startLocked may fail, leaving r.bouncer nil.
|
||||
if r.bouncer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.refs[svcID] = struct{}{}
|
||||
return r.bouncer
|
||||
}
|
||||
|
||||
// Release removes svcID as a consumer. Stops the bouncer when the last
|
||||
// consumer releases.
|
||||
func (r *Registry) Release(svcID types.ServiceID) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.refs, svcID)
|
||||
|
||||
if len(r.refs) == 0 && r.bouncer != nil {
|
||||
r.stopLocked()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) startLocked() {
|
||||
b := NewBouncer(r.apiURL, r.apiKey, r.logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
r.cancel = cancel
|
||||
|
||||
if err := b.Start(ctx); err != nil {
|
||||
r.logger.Errorf("failed to start CrowdSec bouncer: %v", err)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
r.bouncer = b
|
||||
r.logger.Info("CrowdSec bouncer started")
|
||||
}
|
||||
|
||||
func (r *Registry) stopLocked() {
|
||||
r.bouncer.Stop()
|
||||
r.cancel()
|
||||
r.bouncer = nil
|
||||
r.cancel = nil
|
||||
r.logger.Info("CrowdSec bouncer stopped")
|
||||
}
|
||||
66
proxy/internal/crowdsec/registry_test.go
Normal file
66
proxy/internal/crowdsec/registry_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
func TestRegistry_Available(t *testing.T) {
|
||||
r := NewRegistry("http://localhost:8080/", "test-key", log.NewEntry(log.StandardLogger()))
|
||||
assert.True(t, r.Available())
|
||||
|
||||
r2 := NewRegistry("", "", log.NewEntry(log.StandardLogger()))
|
||||
assert.False(t, r2.Available())
|
||||
|
||||
r3 := NewRegistry("http://localhost:8080/", "", log.NewEntry(log.StandardLogger()))
|
||||
assert.False(t, r3.Available())
|
||||
}
|
||||
|
||||
func TestRegistry_Acquire_NotAvailable(t *testing.T) {
|
||||
r := NewRegistry("", "", log.NewEntry(log.StandardLogger()))
|
||||
b := r.Acquire("svc-1")
|
||||
assert.Nil(t, b)
|
||||
}
|
||||
|
||||
func TestRegistry_Acquire_Idempotent(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
b1 := r.Acquire("svc-1")
|
||||
// Can't start without a real LAPI, but we can verify the ref tracking.
|
||||
// The bouncer will be nil because Start fails, but the ref is tracked.
|
||||
_ = b1
|
||||
|
||||
assert.Len(t, r.refs, 1)
|
||||
|
||||
// Second acquire of same service should not add another ref.
|
||||
r.Acquire("svc-1")
|
||||
assert.Len(t, r.refs, 1)
|
||||
}
|
||||
|
||||
func TestRegistry_Release_Removes(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
r.refs[types.ServiceID("svc-1")] = struct{}{}
|
||||
|
||||
r.Release("svc-1")
|
||||
assert.Empty(t, r.refs)
|
||||
}
|
||||
|
||||
func TestRegistry_Release_Noop(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
// Releasing a service that was never acquired should not panic.
|
||||
r.Release("nonexistent")
|
||||
assert.Empty(t, r.refs)
|
||||
}
|
||||
|
||||
func newTestRegistry() *Registry {
|
||||
return &Registry{
|
||||
apiURL: "http://localhost:8080/",
|
||||
apiKey: "test-key",
|
||||
logger: log.NewEntry(log.StandardLogger()),
|
||||
refs: make(map[types.ServiceID]struct{}),
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -52,6 +53,7 @@ type CapturedData struct {
|
||||
clientIP netip.Addr
|
||||
userID string
|
||||
authMethod string
|
||||
metadata map[string]string
|
||||
}
|
||||
|
||||
// NewCapturedData creates a CapturedData with the given request ID.
|
||||
@@ -150,6 +152,23 @@ func (c *CapturedData) GetAuthMethod() string {
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
// SetMetadata sets a key-value pair in the metadata map.
|
||||
func (c *CapturedData) SetMetadata(key, value string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.metadata == nil {
|
||||
c.metadata = make(map[string]string)
|
||||
}
|
||||
c.metadata[key] = value
|
||||
}
|
||||
|
||||
// GetMetadata returns a copy of the metadata map.
|
||||
func (c *CapturedData) GetMetadata() map[string]string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return maps.Clone(c.metadata)
|
||||
}
|
||||
|
||||
// WithCapturedData adds a CapturedData struct to the context.
|
||||
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
|
||||
return context.WithValue(ctx, capturedDataKey, data)
|
||||
|
||||
@@ -12,12 +12,44 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
)
|
||||
|
||||
// defaultLogger is used when no logger is provided to ParseFilter.
|
||||
var defaultLogger = log.NewEntry(log.StandardLogger())
|
||||
|
||||
// GeoResolver resolves an IP address to geographic information.
|
||||
type GeoResolver interface {
|
||||
LookupAddr(addr netip.Addr) geolocation.Result
|
||||
Available() bool
|
||||
}
|
||||
|
||||
// DecisionType is the type of CrowdSec remediation action.
|
||||
type DecisionType string
|
||||
|
||||
const (
|
||||
DecisionBan DecisionType = "ban"
|
||||
DecisionCaptcha DecisionType = "captcha"
|
||||
DecisionThrottle DecisionType = "throttle"
|
||||
)
|
||||
|
||||
// CrowdSecDecision holds the type of a CrowdSec decision.
|
||||
type CrowdSecDecision struct {
|
||||
Type DecisionType
|
||||
}
|
||||
|
||||
// CrowdSecChecker queries CrowdSec decisions for an IP address.
|
||||
type CrowdSecChecker interface {
|
||||
CheckIP(addr netip.Addr) *CrowdSecDecision
|
||||
Ready() bool
|
||||
}
|
||||
|
||||
// CrowdSecMode is the per-service enforcement mode.
|
||||
type CrowdSecMode string
|
||||
|
||||
const (
|
||||
CrowdSecOff CrowdSecMode = ""
|
||||
CrowdSecEnforce CrowdSecMode = "enforce"
|
||||
CrowdSecObserve CrowdSecMode = "observe"
|
||||
)
|
||||
|
||||
// Filter evaluates IP restrictions. CIDR checks are performed first
|
||||
// (cheap), followed by country lookups (more expensive) only when needed.
|
||||
type Filter struct {
|
||||
@@ -25,32 +57,55 @@ type Filter struct {
|
||||
BlockedCIDRs []netip.Prefix
|
||||
AllowedCountries []string
|
||||
BlockedCountries []string
|
||||
CrowdSec CrowdSecChecker
|
||||
CrowdSecMode CrowdSecMode
|
||||
}
|
||||
|
||||
// ParseFilter builds a Filter from the raw string slices. Returns nil
|
||||
// if all slices are empty.
|
||||
func ParseFilter(allowedCIDRs, blockedCIDRs, allowedCountries, blockedCountries []string) *Filter {
|
||||
if len(allowedCIDRs) == 0 && len(blockedCIDRs) == 0 &&
|
||||
len(allowedCountries) == 0 && len(blockedCountries) == 0 {
|
||||
// FilterConfig holds the raw configuration for building a Filter.
|
||||
type FilterConfig struct {
|
||||
AllowedCIDRs []string
|
||||
BlockedCIDRs []string
|
||||
AllowedCountries []string
|
||||
BlockedCountries []string
|
||||
CrowdSec CrowdSecChecker
|
||||
CrowdSecMode CrowdSecMode
|
||||
Logger *log.Entry
|
||||
}
|
||||
|
||||
// ParseFilter builds a Filter from the config. Returns nil if no restrictions
|
||||
// are configured.
|
||||
func ParseFilter(cfg FilterConfig) *Filter {
|
||||
hasCS := cfg.CrowdSecMode == CrowdSecEnforce || cfg.CrowdSecMode == CrowdSecObserve
|
||||
if len(cfg.AllowedCIDRs) == 0 && len(cfg.BlockedCIDRs) == 0 &&
|
||||
len(cfg.AllowedCountries) == 0 && len(cfg.BlockedCountries) == 0 && !hasCS {
|
||||
return nil
|
||||
}
|
||||
|
||||
f := &Filter{
|
||||
AllowedCountries: normalizeCountryCodes(allowedCountries),
|
||||
BlockedCountries: normalizeCountryCodes(blockedCountries),
|
||||
logger := cfg.Logger
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
for _, cidr := range allowedCIDRs {
|
||||
|
||||
f := &Filter{
|
||||
AllowedCountries: normalizeCountryCodes(cfg.AllowedCountries),
|
||||
BlockedCountries: normalizeCountryCodes(cfg.BlockedCountries),
|
||||
}
|
||||
if hasCS {
|
||||
f.CrowdSec = cfg.CrowdSec
|
||||
f.CrowdSecMode = cfg.CrowdSecMode
|
||||
}
|
||||
for _, cidr := range cfg.AllowedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
|
||||
logger.Warnf("skip invalid allowed CIDR %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
f.AllowedCIDRs = append(f.AllowedCIDRs, prefix.Masked())
|
||||
}
|
||||
for _, cidr := range blockedCIDRs {
|
||||
for _, cidr := range cfg.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
log.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
|
||||
logger.Warnf("skip invalid blocked CIDR %q: %v", cidr, err)
|
||||
continue
|
||||
}
|
||||
f.BlockedCIDRs = append(f.BlockedCIDRs, prefix.Masked())
|
||||
@@ -82,6 +137,15 @@ const (
|
||||
// DenyGeoUnavailable indicates that country restrictions are configured
|
||||
// but the geo lookup is unavailable.
|
||||
DenyGeoUnavailable
|
||||
// DenyCrowdSecBan indicates a CrowdSec "ban" decision.
|
||||
DenyCrowdSecBan
|
||||
// DenyCrowdSecCaptcha indicates a CrowdSec "captcha" decision.
|
||||
DenyCrowdSecCaptcha
|
||||
// DenyCrowdSecThrottle indicates a CrowdSec "throttle" decision.
|
||||
DenyCrowdSecThrottle
|
||||
// DenyCrowdSecUnavailable indicates enforce mode but the bouncer has not
|
||||
// completed its initial sync.
|
||||
DenyCrowdSecUnavailable
|
||||
)
|
||||
|
||||
// String returns the deny reason string matching the HTTP auth mechanism names.
|
||||
@@ -95,14 +159,42 @@ func (v Verdict) String() string {
|
||||
return "country_restricted"
|
||||
case DenyGeoUnavailable:
|
||||
return "geo_unavailable"
|
||||
case DenyCrowdSecBan:
|
||||
return "crowdsec_ban"
|
||||
case DenyCrowdSecCaptcha:
|
||||
return "crowdsec_captcha"
|
||||
case DenyCrowdSecThrottle:
|
||||
return "crowdsec_throttle"
|
||||
case DenyCrowdSecUnavailable:
|
||||
return "crowdsec_unavailable"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// IsCrowdSec returns true when the verdict originates from a CrowdSec check.
|
||||
func (v Verdict) IsCrowdSec() bool {
|
||||
switch v {
|
||||
case DenyCrowdSecBan, DenyCrowdSecCaptcha, DenyCrowdSecThrottle, DenyCrowdSecUnavailable:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsObserveOnly returns true when v is a CrowdSec verdict and the filter is in
|
||||
// observe mode. Callers should log the verdict but not block the request.
|
||||
func (f *Filter) IsObserveOnly(v Verdict) bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve
|
||||
}
|
||||
|
||||
// Check evaluates whether addr is permitted. CIDR rules are evaluated
|
||||
// first because they are O(n) prefix comparisons. Country rules run
|
||||
// only when CIDR checks pass and require a geo lookup.
|
||||
// only when CIDR checks pass and require a geo lookup. CrowdSec checks
|
||||
// run last.
|
||||
func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
|
||||
if f == nil {
|
||||
return Allow
|
||||
@@ -115,7 +207,10 @@ func (f *Filter) Check(addr netip.Addr, geo GeoResolver) Verdict {
|
||||
if v := f.checkCIDR(addr); v != Allow {
|
||||
return v
|
||||
}
|
||||
return f.checkCountry(addr, geo)
|
||||
if v := f.checkCountry(addr, geo); v != Allow {
|
||||
return v
|
||||
}
|
||||
return f.checkCrowdSec(addr)
|
||||
}
|
||||
|
||||
func (f *Filter) checkCIDR(addr netip.Addr) Verdict {
|
||||
@@ -173,11 +268,48 @@ func (f *Filter) checkCountry(addr netip.Addr, geo GeoResolver) Verdict {
|
||||
return Allow
|
||||
}
|
||||
|
||||
func (f *Filter) checkCrowdSec(addr netip.Addr) Verdict {
|
||||
if f.CrowdSecMode == CrowdSecOff {
|
||||
return Allow
|
||||
}
|
||||
|
||||
// Checker nil with enforce means CrowdSec was requested but the proxy
|
||||
// has no LAPI configured. Fail-closed.
|
||||
if f.CrowdSec == nil {
|
||||
if f.CrowdSecMode == CrowdSecEnforce {
|
||||
return DenyCrowdSecUnavailable
|
||||
}
|
||||
return Allow
|
||||
}
|
||||
|
||||
if !f.CrowdSec.Ready() {
|
||||
if f.CrowdSecMode == CrowdSecEnforce {
|
||||
return DenyCrowdSecUnavailable
|
||||
}
|
||||
return Allow
|
||||
}
|
||||
|
||||
d := f.CrowdSec.CheckIP(addr)
|
||||
if d == nil {
|
||||
return Allow
|
||||
}
|
||||
|
||||
switch d.Type {
|
||||
case DecisionCaptcha:
|
||||
return DenyCrowdSecCaptcha
|
||||
case DecisionThrottle:
|
||||
return DenyCrowdSecThrottle
|
||||
default:
|
||||
return DenyCrowdSecBan
|
||||
}
|
||||
}
|
||||
|
||||
// HasRestrictions returns true if any restriction rules are configured.
|
||||
func (f *Filter) HasRestrictions() bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
return len(f.AllowedCIDRs) > 0 || len(f.BlockedCIDRs) > 0 ||
|
||||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0
|
||||
len(f.AllowedCountries) > 0 || len(f.BlockedCountries) > 0 ||
|
||||
f.CrowdSecMode == CrowdSecEnforce || f.CrowdSecMode == CrowdSecObserve
|
||||
}
|
||||
|
||||
@@ -29,21 +29,21 @@ func TestFilter_Check_NilFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_BlockedCIDR(t *testing.T) {
|
||||
f := ParseFilter(nil, []string{"10.0.0.0/8"}, nil, nil)
|
||||
f := ParseFilter(FilterConfig{BlockedCIDRs: []string{"10.0.0.0/8"}})
|
||||
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("192.168.1.1"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_Check_AllowedAndBlockedCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, []string{"10.1.0.0/16"}, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCIDRs: []string{"10.1.0.0/16"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), nil), "allowed by allowlist, not in blocklist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "allowed by allowlist but in blocklist")
|
||||
@@ -56,7 +56,7 @@ func TestFilter_Check_AllowedCountry(t *testing.T) {
|
||||
"2.2.2.2": "DE",
|
||||
"3.3.3.3": "CN",
|
||||
})
|
||||
f := ParseFilter(nil, nil, []string{"US", "DE"}, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US in allowlist")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE in allowlist")
|
||||
@@ -69,7 +69,7 @@ func TestFilter_Check_BlockedCountry(t *testing.T) {
|
||||
"2.2.2.2": "RU",
|
||||
"3.3.3.3": "US",
|
||||
})
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN", "RU"})
|
||||
f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN", "RU"}})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "CN in blocklist")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "RU in blocklist")
|
||||
@@ -83,7 +83,7 @@ func TestFilter_Check_AllowedAndBlockedCountry(t *testing.T) {
|
||||
"3.3.3.3": "CN",
|
||||
})
|
||||
// Allow US and DE, but block DE explicitly.
|
||||
f := ParseFilter(nil, nil, []string{"US", "DE"}, []string{"DE"})
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: []string{"US", "DE"}, BlockedCountries: []string{"DE"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "US allowed and not blocked")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("2.2.2.2"), geo), "DE allowed but also blocked, block wins")
|
||||
@@ -94,7 +94,7 @@ func TestFilter_Check_UnknownCountryWithAllowlist(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "US",
|
||||
})
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known US in allowlist")
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country denied when allowlist is active")
|
||||
@@ -104,34 +104,34 @@ func TestFilter_Check_UnknownCountryWithBlocklistOnly(t *testing.T) {
|
||||
geo := newMockGeo(map[string]string{
|
||||
"1.1.1.1": "CN",
|
||||
})
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("1.1.1.1"), geo), "known CN in blocklist")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("9.9.9.9"), geo), "unknown country allowed when only blocklist is active")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CountryWithoutGeo(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}})
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country allowlist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CountryBlocklistWithoutGeo(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
f := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}})
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil), "nil geo with country blocklist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_GeoUnavailable(t *testing.T) {
|
||||
geo := &unavailableGeo{}
|
||||
|
||||
f := ParseFilter(nil, nil, []string{"US"}, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: []string{"US"}})
|
||||
assert.Equal(t, DenyGeoUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country allowlist")
|
||||
|
||||
f2 := ParseFilter(nil, nil, nil, []string{"CN"})
|
||||
f2 := ParseFilter(FilterConfig{BlockedCountries: []string{"CN"}})
|
||||
assert.Equal(t, DenyGeoUnavailable, f2.Check(netip.MustParseAddr("1.2.3.4"), geo), "unavailable geo with country blocklist")
|
||||
}
|
||||
|
||||
func TestFilter_Check_CIDROnlySkipsGeo(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
|
||||
// CIDR-only filter should never touch geo, so nil geo is fine.
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil))
|
||||
@@ -143,7 +143,7 @@ func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
|
||||
"10.1.2.3": "CN",
|
||||
"10.2.3.4": "US",
|
||||
})
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, []string{"CN"})
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, BlockedCountries: []string{"CN"}})
|
||||
|
||||
assert.Equal(t, DenyCountry, f.Check(netip.MustParseAddr("10.1.2.3"), geo), "CIDR allowed but country blocked")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.2.3.4"), geo), "CIDR allowed and country not blocked")
|
||||
@@ -151,12 +151,12 @@ func TestFilter_Check_CIDRAllowThenCountryBlock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseFilter_Empty(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{})
|
||||
assert.Nil(t, f)
|
||||
}
|
||||
|
||||
func TestParseFilter_InvalidCIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"invalid", "10.0.0.0/8"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"invalid", "10.0.0.0/8"}})
|
||||
|
||||
assert.NotNil(t, f)
|
||||
assert.Len(t, f.AllowedCIDRs, 1, "invalid CIDR should be skipped")
|
||||
@@ -166,12 +166,12 @@ func TestParseFilter_InvalidCIDR(t *testing.T) {
|
||||
func TestFilter_HasRestrictions(t *testing.T) {
|
||||
assert.False(t, (*Filter)(nil).HasRestrictions())
|
||||
assert.False(t, (&Filter{}).HasRestrictions())
|
||||
assert.True(t, ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil).HasRestrictions())
|
||||
assert.True(t, ParseFilter(nil, nil, []string{"US"}, nil).HasRestrictions())
|
||||
assert.True(t, ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}).HasRestrictions())
|
||||
assert.True(t, ParseFilter(FilterConfig{AllowedCountries: []string{"US"}}).HasRestrictions())
|
||||
}
|
||||
|
||||
func TestFilter_Check_IPv6CIDR(t *testing.T) {
|
||||
f := ParseFilter([]string{"2001:db8::/32"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"2001:db8::/32"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 addr in v6 allowlist")
|
||||
assert.Equal(t, DenyCIDR, f.Check(netip.MustParseAddr("2001:db9::1"), nil), "v6 addr not in v6 allowlist")
|
||||
@@ -179,7 +179,7 @@ func TestFilter_Check_IPv6CIDR(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
|
||||
// A v4-mapped-v6 address like ::ffff:10.1.2.3 must match a v4 CIDR.
|
||||
v4mapped := netip.MustParseAddr("::ffff:10.1.2.3")
|
||||
@@ -191,7 +191,7 @@ func TestFilter_Check_IPv4MappedIPv6(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
|
||||
f := ParseFilter([]string{"10.0.0.0/8", "2001:db8::/32"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8", "2001:db8::/32"}})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("10.1.2.3"), nil), "v4 in v4 CIDR")
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("2001:db8::1"), nil), "v6 in v6 CIDR")
|
||||
@@ -202,7 +202,7 @@ func TestFilter_Check_MixedV4V6CIDRs(t *testing.T) {
|
||||
|
||||
func TestParseFilter_CanonicalizesNonMaskedCIDR(t *testing.T) {
|
||||
// 1.1.1.1/24 has host bits set; ParseFilter should canonicalize to 1.1.1.0/24.
|
||||
f := ParseFilter([]string{"1.1.1.1/24"}, nil, nil, nil)
|
||||
f := ParseFilter(FilterConfig{AllowedCIDRs: []string{"1.1.1.1/24"}})
|
||||
assert.Equal(t, netip.MustParsePrefix("1.1.1.0/24"), f.AllowedCIDRs[0])
|
||||
|
||||
// Verify it still matches correctly.
|
||||
@@ -264,7 +264,7 @@ func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f := ParseFilter(nil, nil, tc.allowedCountries, tc.blockedCountries)
|
||||
f := ParseFilter(FilterConfig{AllowedCountries: tc.allowedCountries, BlockedCountries: tc.blockedCountries})
|
||||
got := f.Check(netip.MustParseAddr(tc.addr), geo)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
@@ -275,4 +275,252 @@ func TestFilter_Check_CountryCodeCaseInsensitive(t *testing.T) {
|
||||
type unavailableGeo struct{}
|
||||
|
||||
func (u *unavailableGeo) LookupAddr(_ netip.Addr) geolocation.Result { return geolocation.Result{} }
|
||||
func (u *unavailableGeo) Available() bool { return false }
|
||||
func (u *unavailableGeo) Available() bool { return false }
|
||||
|
||||
// mockCrowdSec is a test implementation of CrowdSecChecker.
|
||||
type mockCrowdSec struct {
|
||||
decisions map[string]*CrowdSecDecision
|
||||
ready bool
|
||||
}
|
||||
|
||||
func (m *mockCrowdSec) CheckIP(addr netip.Addr) *CrowdSecDecision {
|
||||
return m.decisions[addr.Unmap().String()]
|
||||
}
|
||||
|
||||
func (m *mockCrowdSec) Ready() bool { return m.ready }
|
||||
|
||||
func TestFilter_CrowdSec_Enforce_Ban(t *testing.T) {
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}},
|
||||
ready: true,
|
||||
}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
assert.Equal(t, DenyCrowdSecBan, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("5.6.7.8"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Enforce_Captcha(t *testing.T) {
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionCaptcha}},
|
||||
ready: true,
|
||||
}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
assert.Equal(t, DenyCrowdSecCaptcha, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Enforce_Throttle(t *testing.T) {
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionThrottle}},
|
||||
ready: true,
|
||||
}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
assert.Equal(t, DenyCrowdSecThrottle, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Observe_DoesNotBlock(t *testing.T) {
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}},
|
||||
ready: true,
|
||||
}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve})
|
||||
|
||||
verdict := f.Check(netip.MustParseAddr("1.2.3.4"), nil)
|
||||
assert.Equal(t, DenyCrowdSecBan, verdict, "verdict should be ban")
|
||||
assert.True(t, f.IsObserveOnly(verdict), "should be observe-only")
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Enforce_NotReady(t *testing.T) {
|
||||
cs := &mockCrowdSec{ready: false}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Observe_NotReady_Allows(t *testing.T) {
|
||||
cs := &mockCrowdSec{ready: false}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecObserve})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Off(t *testing.T) {
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{"1.2.3.4": {Type: DecisionBan}},
|
||||
ready: true,
|
||||
}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecOff})
|
||||
|
||||
// CrowdSecOff means the filter is nil (no restrictions).
|
||||
assert.Nil(t, f)
|
||||
}
|
||||
|
||||
func TestFilter_IsObserveOnly(t *testing.T) {
|
||||
f := &Filter{CrowdSecMode: CrowdSecObserve}
|
||||
assert.True(t, f.IsObserveOnly(DenyCrowdSecBan))
|
||||
assert.True(t, f.IsObserveOnly(DenyCrowdSecCaptcha))
|
||||
assert.True(t, f.IsObserveOnly(DenyCrowdSecThrottle))
|
||||
assert.True(t, f.IsObserveOnly(DenyCrowdSecUnavailable))
|
||||
assert.False(t, f.IsObserveOnly(DenyCIDR))
|
||||
assert.False(t, f.IsObserveOnly(Allow))
|
||||
|
||||
f2 := &Filter{CrowdSecMode: CrowdSecEnforce}
|
||||
assert.False(t, f2.IsObserveOnly(DenyCrowdSecBan))
|
||||
}
|
||||
|
||||
// TestFilter_LayerInteraction exercises the evaluation order across all three
|
||||
// restriction layers: CIDR -> Country -> CrowdSec. Each layer can only further
|
||||
// restrict; no layer can relax a denial from an earlier layer.
|
||||
//
|
||||
// Layer order | Behavior
|
||||
// ---------------|-------------------------------------------------------
|
||||
// 1. CIDR | Allowlist narrows to specific ranges, blocklist removes
|
||||
// | specific ranges. Deny here → stop, CrowdSec never runs.
|
||||
// 2. Country | Allowlist/blocklist by geo. Deny here → stop.
|
||||
// 3. CrowdSec | IP reputation. Can block IPs that passed layers 1-2.
|
||||
// | Observe mode: verdict returned but caller doesn't block.
|
||||
func TestFilter_LayerInteraction(t *testing.T) {
|
||||
bannedIP := "10.1.2.3"
|
||||
cleanIP := "10.2.3.4"
|
||||
outsideIP := "192.168.1.1"
|
||||
|
||||
cs := &mockCrowdSec{
|
||||
decisions: map[string]*CrowdSecDecision{bannedIP: {Type: DecisionBan}},
|
||||
ready: true,
|
||||
}
|
||||
geo := newMockGeo(map[string]string{
|
||||
bannedIP: "US",
|
||||
cleanIP: "US",
|
||||
outsideIP: "CN",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config FilterConfig
|
||||
addr string
|
||||
want Verdict
|
||||
}{
|
||||
// CIDR allowlist + CrowdSec enforce: CrowdSec blocks inside allowed range
|
||||
{
|
||||
name: "allowed CIDR + CrowdSec banned",
|
||||
config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: bannedIP,
|
||||
want: DenyCrowdSecBan,
|
||||
},
|
||||
{
|
||||
name: "allowed CIDR + CrowdSec clean",
|
||||
config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: cleanIP,
|
||||
want: Allow,
|
||||
},
|
||||
{
|
||||
name: "CIDR deny stops before CrowdSec",
|
||||
config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: outsideIP,
|
||||
want: DenyCIDR,
|
||||
},
|
||||
|
||||
// CIDR blocklist + CrowdSec enforce: blocklist blocks first, CrowdSec blocks remaining
|
||||
{
|
||||
name: "blocked CIDR stops before CrowdSec",
|
||||
config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: bannedIP,
|
||||
want: DenyCIDR,
|
||||
},
|
||||
{
|
||||
name: "not in blocklist + CrowdSec clean",
|
||||
config: FilterConfig{BlockedCIDRs: []string{"10.1.0.0/16"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: cleanIP,
|
||||
want: Allow,
|
||||
},
|
||||
|
||||
// Country allowlist + CrowdSec enforce
|
||||
{
|
||||
name: "allowed country + CrowdSec banned",
|
||||
config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: bannedIP,
|
||||
want: DenyCrowdSecBan,
|
||||
},
|
||||
{
|
||||
name: "country deny stops before CrowdSec",
|
||||
config: FilterConfig{AllowedCountries: []string{"US"}, CrowdSec: cs, CrowdSecMode: CrowdSecEnforce},
|
||||
addr: outsideIP,
|
||||
want: DenyCountry,
|
||||
},
|
||||
|
||||
// All three layers: CIDR allowlist + country blocklist + CrowdSec
|
||||
{
|
||||
name: "all layers: CIDR allow + country allow + CrowdSec ban",
|
||||
config: FilterConfig{
|
||||
AllowedCIDRs: []string{"10.0.0.0/8"},
|
||||
BlockedCountries: []string{"CN"},
|
||||
CrowdSec: cs,
|
||||
CrowdSecMode: CrowdSecEnforce,
|
||||
},
|
||||
addr: bannedIP, // 10.x (CIDR ok), US (country ok), banned (CrowdSec deny)
|
||||
want: DenyCrowdSecBan,
|
||||
},
|
||||
{
|
||||
name: "all layers: CIDR deny short-circuits everything",
|
||||
config: FilterConfig{
|
||||
AllowedCIDRs: []string{"10.0.0.0/8"},
|
||||
BlockedCountries: []string{"CN"},
|
||||
CrowdSec: cs,
|
||||
CrowdSecMode: CrowdSecEnforce,
|
||||
},
|
||||
addr: outsideIP, // 192.x (CIDR deny)
|
||||
want: DenyCIDR,
|
||||
},
|
||||
|
||||
// Observe mode: verdict returned but IsObserveOnly is true
|
||||
{
|
||||
name: "observe mode: CrowdSec banned inside allowed CIDR",
|
||||
config: FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}, CrowdSec: cs, CrowdSecMode: CrowdSecObserve},
|
||||
addr: bannedIP,
|
||||
want: DenyCrowdSecBan, // verdict is ban, caller checks IsObserveOnly
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f := ParseFilter(tc.config)
|
||||
got := f.Check(netip.MustParseAddr(tc.addr), geo)
|
||||
assert.Equal(t, tc.want, got)
|
||||
|
||||
// Verify observe mode flag when applicable.
|
||||
if tc.config.CrowdSecMode == CrowdSecObserve && got.IsCrowdSec() {
|
||||
assert.True(t, f.IsObserveOnly(got), "observe mode verdict should be observe-only")
|
||||
}
|
||||
if tc.config.CrowdSecMode == CrowdSecEnforce && got.IsCrowdSec() {
|
||||
assert.False(t, f.IsObserveOnly(got), "enforce mode verdict should not be observe-only")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Enforce_NilChecker(t *testing.T) {
|
||||
// LAPI not configured: checker is nil but mode is enforce. Must fail closed.
|
||||
f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
assert.Equal(t, DenyCrowdSecUnavailable, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) {
|
||||
// LAPI not configured: checker is nil but mode is observe. Must allow.
|
||||
f := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecObserve})
|
||||
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_HasRestrictions_CrowdSec(t *testing.T) {
|
||||
cs := &mockCrowdSec{ready: true}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
assert.True(t, f.HasRestrictions())
|
||||
|
||||
// Enforce mode without checker (LAPI not configured): still has restrictions
|
||||
// because Check() will fail-closed with DenyCrowdSecUnavailable.
|
||||
f2 := ParseFilter(FilterConfig{CrowdSec: nil, CrowdSecMode: CrowdSecEnforce})
|
||||
assert.True(t, f2.HasRestrictions())
|
||||
}
|
||||
|
||||
@@ -479,9 +479,14 @@ func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict
|
||||
// On success (nil error), both conn and backend are closed by the relay.
|
||||
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
|
||||
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
|
||||
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
|
||||
r.logL4Deny(route, conn, verdict)
|
||||
return errAccessRestricted
|
||||
if route.Filter != nil && route.Filter.IsObserveOnly(verdict) {
|
||||
r.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", conn.RemoteAddr(), sni, verdict)
|
||||
r.logL4Deny(route, conn, verdict, true)
|
||||
} else {
|
||||
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
|
||||
r.logL4Deny(route, conn, verdict, false)
|
||||
return errAccessRestricted
|
||||
}
|
||||
}
|
||||
|
||||
svcCtx, err := r.acquireRelay(ctx, route)
|
||||
@@ -610,7 +615,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
|
||||
}
|
||||
|
||||
// logL4Deny sends an access log entry for a denied connection.
|
||||
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
|
||||
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict, observeOnly bool) {
|
||||
r.mu.RLock()
|
||||
al := r.accessLog
|
||||
r.mu.RUnlock()
|
||||
@@ -621,14 +626,22 @@ func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict)
|
||||
|
||||
sourceIP, _ := addrFromConn(conn)
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
entry := accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
ServiceID: route.ServiceID,
|
||||
Protocol: route.Protocol,
|
||||
Host: route.Domain,
|
||||
SourceIP: sourceIP,
|
||||
DenyReason: verdict.String(),
|
||||
})
|
||||
}
|
||||
if verdict.IsCrowdSec() {
|
||||
entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()}
|
||||
if observeOnly {
|
||||
entry.Metadata["crowdsec_mode"] = "observe"
|
||||
entry.DenyReason = ""
|
||||
}
|
||||
}
|
||||
al.LogL4(entry)
|
||||
}
|
||||
|
||||
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
||||
|
||||
@@ -1686,7 +1686,7 @@ func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
|
||||
|
||||
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
route := Route{Filter: filter}
|
||||
|
||||
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
|
||||
@@ -1695,7 +1695,7 @@ func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
|
||||
|
||||
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
route := Route{Filter: filter}
|
||||
|
||||
conn := &fakeConn{remote: nil}
|
||||
@@ -1704,7 +1704,7 @@ func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
|
||||
|
||||
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
route := Route{Filter: filter}
|
||||
|
||||
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
|
||||
@@ -1724,7 +1724,7 @@ func TestCheckRestrictions_NilFilter(t *testing.T) {
|
||||
|
||||
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
|
||||
router := NewPortRouter(log.StandardLogger(), nil)
|
||||
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
|
||||
filter := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})
|
||||
route := Route{Filter: filter}
|
||||
|
||||
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
|
||||
|
||||
@@ -336,8 +336,13 @@ func (r *Relay) checkAccessRestrictions(addr net.Addr) error {
|
||||
return fmt.Errorf("parse client address %s for restriction check: %w", addr, err)
|
||||
}
|
||||
if v := r.filter.Check(clientIP, r.geo); v != restrict.Allow {
|
||||
r.logDeny(clientIP, v)
|
||||
return fmt.Errorf("access restricted for %s", addr)
|
||||
if r.filter.IsObserveOnly(v) {
|
||||
r.logger.Debugf("CrowdSec observe: would block %s (%s)", clientIP, v)
|
||||
r.logDeny(clientIP, v, true)
|
||||
} else {
|
||||
r.logDeny(clientIP, v, false)
|
||||
return fmt.Errorf("access restricted for %s", addr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -498,19 +503,27 @@ func (r *Relay) logSessionEnd(sess *session) {
|
||||
}
|
||||
|
||||
// logDeny sends an access log entry for a denied UDP packet.
|
||||
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict) {
|
||||
func (r *Relay) logDeny(clientIP netip.Addr, verdict restrict.Verdict, observeOnly bool) {
|
||||
if r.accessLog == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.accessLog.LogL4(accesslog.L4Entry{
|
||||
entry := accesslog.L4Entry{
|
||||
AccountID: r.accountID,
|
||||
ServiceID: r.serviceID,
|
||||
Protocol: accesslog.ProtocolUDP,
|
||||
Host: r.domain,
|
||||
SourceIP: clientIP,
|
||||
DenyReason: verdict.String(),
|
||||
})
|
||||
}
|
||||
if verdict.IsCrowdSec() {
|
||||
entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()}
|
||||
if observeOnly {
|
||||
entry.Metadata["crowdsec_mode"] = "observe"
|
||||
entry.DenyReason = ""
|
||||
}
|
||||
}
|
||||
r.accessLog.LogL4(entry)
|
||||
}
|
||||
|
||||
// Close stops the relay, waits for all session goroutines to exit,
|
||||
|
||||
@@ -228,6 +228,10 @@ func (m *testProxyManager) ClusterRequireSubdomain(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/crowdsec"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
@@ -100,6 +101,13 @@ type Server struct {
|
||||
geo restrict.GeoResolver
|
||||
geoRaw *geolocation.Lookup
|
||||
|
||||
// crowdsecRegistry manages the shared CrowdSec bouncer lifecycle.
|
||||
crowdsecRegistry *crowdsec.Registry
|
||||
// crowdsecServices tracks which services have CrowdSec enabled for
|
||||
// proper acquire/release lifecycle management.
|
||||
crowdsecMu sync.Mutex
|
||||
crowdsecServices map[types.ServiceID]bool
|
||||
|
||||
// routerReady is closed once mainRouter is fully initialized.
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
@@ -175,6 +183,10 @@ type Server struct {
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files for
|
||||
// country-based access restrictions. Empty disables geo lookups.
|
||||
GeoDataDir string
|
||||
// CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec.
|
||||
CrowdSecAPIURL string
|
||||
// CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables CrowdSec.
|
||||
CrowdSecAPIKey string
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
@@ -275,6 +287,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
// management connectivity from the first stream connection.
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
||||
s.crowdsecServices = make(map[types.ServiceID]bool)
|
||||
|
||||
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
||||
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
@@ -763,6 +778,22 @@ func (s *Server) shutdownServices() {
|
||||
s.Logger.Debugf("close geolocation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.shutdownCrowdSec()
|
||||
}
|
||||
|
||||
func (s *Server) shutdownCrowdSec() {
|
||||
if s.crowdsecRegistry == nil {
|
||||
return
|
||||
}
|
||||
s.crowdsecMu.Lock()
|
||||
services := maps.Clone(s.crowdsecServices)
|
||||
maps.Clear(s.crowdsecServices)
|
||||
s.crowdsecMu.Unlock()
|
||||
|
||||
for svcID := range services {
|
||||
s.crowdsecRegistry.Release(svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveDialFunc returns a DialContextFunc that dials through the
|
||||
@@ -916,6 +947,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
@@ -924,6 +956,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
Capabilities: &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1159,7 +1192,7 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
Filter: s.parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
s.portMu.Lock()
|
||||
@@ -1234,7 +1267,7 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
Filter: s.parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
if tlsPort != s.mainPort {
|
||||
@@ -1268,12 +1301,51 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
|
||||
|
||||
// parseRestrictions converts a proto mapping's access restrictions into
|
||||
// a restrict.Filter. Returns nil if the mapping has no restrictions.
|
||||
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
|
||||
func (s *Server) parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
|
||||
r := mapping.GetAccessRestrictions()
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
|
||||
|
||||
svcID := types.ServiceID(mapping.GetId())
|
||||
csMode := restrict.CrowdSecMode(r.GetCrowdsecMode())
|
||||
|
||||
var checker restrict.CrowdSecChecker
|
||||
if csMode == restrict.CrowdSecEnforce || csMode == restrict.CrowdSecObserve {
|
||||
if b := s.crowdsecRegistry.Acquire(svcID); b != nil {
|
||||
checker = b
|
||||
s.crowdsecMu.Lock()
|
||||
s.crowdsecServices[svcID] = true
|
||||
s.crowdsecMu.Unlock()
|
||||
} else {
|
||||
s.Logger.Warnf("service %s requests CrowdSec mode %q but proxy has no CrowdSec configured", svcID, csMode)
|
||||
// Keep the mode: restrict.Filter will fail-closed for enforce (DenyCrowdSecUnavailable)
|
||||
// and allow for observe.
|
||||
}
|
||||
}
|
||||
|
||||
return restrict.ParseFilter(restrict.FilterConfig{
|
||||
AllowedCIDRs: r.GetAllowedCidrs(),
|
||||
BlockedCIDRs: r.GetBlockedCidrs(),
|
||||
AllowedCountries: r.GetAllowedCountries(),
|
||||
BlockedCountries: r.GetBlockedCountries(),
|
||||
CrowdSec: checker,
|
||||
CrowdSecMode: csMode,
|
||||
Logger: log.NewEntry(s.Logger),
|
||||
})
|
||||
}
|
||||
|
||||
// releaseCrowdSec releases the CrowdSec bouncer reference for the given
|
||||
// service if it had one.
|
||||
func (s *Server) releaseCrowdSec(svcID types.ServiceID) {
|
||||
s.crowdsecMu.Lock()
|
||||
had := s.crowdsecServices[svcID]
|
||||
delete(s.crowdsecServices, svcID)
|
||||
s.crowdsecMu.Unlock()
|
||||
|
||||
if had {
|
||||
s.crowdsecRegistry.Release(svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
|
||||
@@ -1388,7 +1460,7 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
AccessLog: s.accessLog,
|
||||
Filter: parseRestrictions(mapping),
|
||||
Filter: s.parseRestrictions(mapping),
|
||||
Geo: s.geo,
|
||||
})
|
||||
relay.SetObserver(s.meter)
|
||||
@@ -1425,7 +1497,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
|
||||
}
|
||||
|
||||
ipRestrictions := parseRestrictions(mapping)
|
||||
ipRestrictions := s.parseRestrictions(mapping)
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
@@ -1507,6 +1579,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
|
||||
// UDP relay cleanup (idempotent).
|
||||
s.removeUDPRelay(svcID)
|
||||
|
||||
// Release CrowdSec after all routes are removed so the shared bouncer
|
||||
// isn't stopped while stale filters can still be reached by in-flight requests.
|
||||
s.releaseCrowdSec(svcID)
|
||||
}
|
||||
|
||||
// removeUDPRelay stops and removes a UDP relay by service ID.
|
||||
|
||||
Reference in New Issue
Block a user