mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-29 13:46:41 +00:00
Merge branch 'main' into proto-ipv6-overlay
This commit is contained in:
@@ -82,6 +82,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
mockCaps := proxy.NewMockManager(ctrl)
|
||||
mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes()
|
||||
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
|
||||
mockCaps.EXPECT().ClusterSupportsCrowdSec(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -29,6 +31,13 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func testCacheStore(t *testing.T) cachestore.StoreInterface {
|
||||
t.Helper()
|
||||
s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestInitializeServiceForCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
@@ -423,10 +432,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
|
||||
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
|
||||
t.Helper()
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
@@ -705,10 +712,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
},
|
||||
}
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
@@ -1131,10 +1136,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
|
||||
@@ -113,6 +113,7 @@ type AccessRestrictions struct {
|
||||
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
|
||||
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
|
||||
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
|
||||
CrowdSecMode string `json:"crowdsec_mode,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the AccessRestrictions.
|
||||
@@ -122,6 +123,7 @@ func (r AccessRestrictions) Copy() AccessRestrictions {
|
||||
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
|
||||
AllowedCountries: slices.Clone(r.AllowedCountries),
|
||||
BlockedCountries: slices.Clone(r.BlockedCountries),
|
||||
CrowdSecMode: r.CrowdSecMode,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,7 +557,11 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
}
|
||||
|
||||
if req.AccessRestrictions != nil {
|
||||
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
|
||||
restrictions, err := restrictionsFromAPI(req.AccessRestrictions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Restrictions = restrictions
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -631,9 +637,9 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
|
||||
return auth
|
||||
}
|
||||
|
||||
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
|
||||
func restrictionsFromAPI(r *api.AccessRestrictions) (AccessRestrictions, error) {
|
||||
if r == nil {
|
||||
return AccessRestrictions{}
|
||||
return AccessRestrictions{}, nil
|
||||
}
|
||||
var res AccessRestrictions
|
||||
if r.AllowedCidrs != nil {
|
||||
@@ -648,11 +654,19 @@ func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
|
||||
if r.BlockedCountries != nil {
|
||||
res.BlockedCountries = *r.BlockedCountries
|
||||
}
|
||||
return res
|
||||
if r.CrowdsecMode != nil {
|
||||
if !r.CrowdsecMode.Valid() {
|
||||
return AccessRestrictions{}, fmt.Errorf("invalid crowdsec_mode %q", *r.CrowdsecMode)
|
||||
}
|
||||
res.CrowdSecMode = string(*r.CrowdsecMode)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 &&
|
||||
len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 &&
|
||||
r.CrowdSecMode == "" {
|
||||
return nil
|
||||
}
|
||||
res := &api.AccessRestrictions{}
|
||||
@@ -668,11 +682,17 @@ func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
|
||||
if len(r.BlockedCountries) > 0 {
|
||||
res.BlockedCountries = &r.BlockedCountries
|
||||
}
|
||||
if r.CrowdSecMode != "" {
|
||||
mode := api.AccessRestrictionsCrowdsecMode(r.CrowdSecMode)
|
||||
res.CrowdsecMode = &mode
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 &&
|
||||
len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 &&
|
||||
r.CrowdSecMode == "" {
|
||||
return nil
|
||||
}
|
||||
return &proto.AccessRestrictions{
|
||||
@@ -680,6 +700,7 @@ func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
|
||||
BlockedCidrs: r.BlockedCIDRs,
|
||||
AllowedCountries: r.AllowedCountries,
|
||||
BlockedCountries: r.BlockedCountries,
|
||||
CrowdsecMode: r.CrowdSecMode,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -988,7 +1009,20 @@ const (
|
||||
|
||||
// validateAccessRestrictions validates and normalizes access restriction
|
||||
// entries. Country codes are uppercased in place.
|
||||
func validateCrowdSecMode(mode string) error {
|
||||
switch mode {
|
||||
case "", "off", "enforce", "observe":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("crowdsec_mode %q is invalid", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func validateAccessRestrictions(r *AccessRestrictions) error {
|
||||
if err := validateCrowdSecMode(r.CrowdSecMode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.AllowedCIDRs) > maxCIDREntries {
|
||||
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
|
||||
}
|
||||
@@ -1002,35 +1036,37 @@ func validateAccessRestrictions(r *AccessRestrictions) error {
|
||||
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
|
||||
}
|
||||
|
||||
for i, raw := range r.AllowedCIDRs {
|
||||
if err := validateCIDRList("allowed_cidrs", r.AllowedCIDRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateCIDRList("blocked_cidrs", r.BlockedCIDRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := normalizeCountryList("allowed_countries", r.AllowedCountries); err != nil {
|
||||
return err
|
||||
}
|
||||
return normalizeCountryList("blocked_countries", r.BlockedCountries)
|
||||
}
|
||||
|
||||
func validateCIDRList(field string, cidrs []string) error {
|
||||
for i, raw := range cidrs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
|
||||
return fmt.Errorf("%s[%d]: %w", field, i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
return fmt.Errorf("%s[%d]: %q has host bits set, use %s instead", field, i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, raw := range r.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, code := range r.AllowedCountries {
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeCountryList(field string, codes []string) error {
|
||||
for i, code := range codes {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
return fmt.Errorf("%s[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", field, i, code)
|
||||
}
|
||||
r.AllowedCountries[i] = strings.ToUpper(code)
|
||||
}
|
||||
for i, code := range r.BlockedCountries {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
}
|
||||
r.BlockedCountries[i] = strings.ToUpper(code)
|
||||
codes[i] = strings.ToUpper(code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user