mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-07 00:19:56 +00:00
Compare commits
3 Commits
v0.72.2
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37ce40ede6 | ||
|
|
318014552f | ||
|
|
78ed62535e |
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash -x release_files/freebsd-port-diff.sh
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash -x release_files/freebsd-port-issue-body.sh
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -120,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.InstanceManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.AuthMiddleware())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -153,6 +153,20 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthMiddleware() mux.MiddlewareFunc {
|
||||
return Create(s, func() mux.MiddlewareFunc {
|
||||
m := middleware.NewAuthMiddleware(
|
||||
s.AuthManager(),
|
||||
s.AccountManager().GetAccountIDFromUserAuth,
|
||||
s.AccountManager().SyncUserJWTGroups,
|
||||
s.AccountManager().GetUserFromUserAuth,
|
||||
s.RateLimiter(),
|
||||
s.Metrics().GetMeter(),
|
||||
)
|
||||
return m.Handler
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
@@ -151,6 +152,16 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) InstanceManager() instance.Manager {
|
||||
return Create(s, func() instance.Manager {
|
||||
m, err := instance.NewManager(context.Background(), s.Store(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create instance manager: %v", err)
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
@@ -229,6 +240,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -147,9 +145,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
response.NetbirdConfig = toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
@@ -363,7 +359,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
|
||||
@@ -978,7 +978,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
Private: m.Private,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
|
||||
// copy from the source, since callers assign it individually after cloning.
|
||||
const authTokenField = "AuthToken"
|
||||
|
||||
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
|
||||
// ProxyMapping with a non-zero value and verifies the clone carries each one
|
||||
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
|
||||
// without updating shallowCloneMapping fails this test.
|
||||
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
|
||||
src := &proto.ProxyMapping{}
|
||||
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
|
||||
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
|
||||
|
||||
clone := shallowCloneMapping(src)
|
||||
require.NotNil(t, clone, "clone must not be nil")
|
||||
|
||||
srcVal := reflect.ValueOf(src).Elem()
|
||||
cloneVal := reflect.ValueOf(clone).Elem()
|
||||
|
||||
for _, name := range populated {
|
||||
srcField := srcVal.FieldByName(name).Interface()
|
||||
cloneField := cloneVal.FieldByName(name).Interface()
|
||||
|
||||
if name == authTokenField {
|
||||
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateExportedFields sets a non-zero value on every settable exported field
|
||||
// of the struct and returns their names.
|
||||
func populateExportedFields(t *testing.T, v reflect.Value) []string {
|
||||
t.Helper()
|
||||
|
||||
var names []string
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
structField := typ.Field(i)
|
||||
|
||||
if structField.PkgPath != "" || !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
setNonZero(t, field, structField.Name)
|
||||
names = append(names, structField.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value based on the field kind.
|
||||
func setNonZero(t *testing.T, field reflect.Value, name string) {
|
||||
t.Helper()
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero-" + name)
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Ptr:
|
||||
field.Set(reflect.New(field.Type().Elem()))
|
||||
case reflect.Slice:
|
||||
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
|
||||
case reflect.Map:
|
||||
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -41,8 +40,6 @@ type TimeBasedAuthSecretsManager struct {
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
wgKey wgtypes.Key
|
||||
@@ -62,8 +59,6 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
||||
relayCfg: relayCfg,
|
||||
turnCancelMap: make(map[string]chan struct{}),
|
||||
relayCancelMap: make(map[string]chan struct{}),
|
||||
settingsManager: settingsManager,
|
||||
groupsManager: groupsManager,
|
||||
wgKey: key,
|
||||
}
|
||||
|
||||
@@ -239,8 +234,6 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
}
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
@@ -266,26 +259,9 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
},
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
|
||||
}
|
||||
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||
update.NetbirdConfig = extendedConfig
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
@@ -34,7 +32,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||
@@ -49,7 +46,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
@@ -59,7 +55,7 @@ import (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, instanceManager nbinstance.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, authMiddleware mux.MiddlewareFunc) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -80,32 +76,11 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware)
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
@@ -25,7 +26,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
// jwtTokenCtxKey carries the parsed JWT token.
|
||||
type jwtTokenCtxKey struct{}
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
@@ -35,7 +37,6 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -46,7 +47,6 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -64,12 +64,18 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
||||
// Handler composes the full authentication chain by wrapping the given
|
||||
// handler with ValidationHandler followed by AccountAccessHandler.
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return m.ValidationHandler(m.AccountAccessHandler(h))
|
||||
}
|
||||
|
||||
// ValidationHandler authenticates the caller via JWT or PAT and stores the
|
||||
// resulting UserAuth in the request context. It performs no account-level work.
|
||||
func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
@@ -86,14 +92,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
if err := m.validateJWT(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
if err := m.validatePAT(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -110,66 +116,55 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
// AccountAccessHandler runs post-validation access checks for JWT-authenticated
|
||||
// requests. PAT requests pass through unchanged.
|
||||
func (m *AuthMiddleware) AccountAccessHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if userAuth.IsPAT {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
validatedToken, _ := r.Context().Value(jwtTokenCtxKey{}).(*jwt.Token)
|
||||
|
||||
if err := m.applyAccountAccess(r, userAuth, validatedToken); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error applying JWT account access: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) validateJWT(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(r.Context(), token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
}
|
||||
|
||||
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
|
||||
// Available as userAuth.Email
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
*r = *r.WithContext(context.WithValue(r.Context(), jwtTokenCtxKey{}, validatedToken))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
@@ -192,8 +187,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
if err := m.authManager.MarkPATUsed(ctx, pat.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -205,11 +199,40 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
IsPAT: true,
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyAccountAccess executes account-level checks for an authenticated JWT
|
||||
// user: ensures the account exists, verifies access via JWT groups, syncs
|
||||
// groups, and fetches the user record.
|
||||
func (m *AuthMiddleware) applyAccountAccess(r *http.Request, userAuth auth.UserAuth, validatedToken *jwt.Token) error {
|
||||
ctx := r.Context()
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.syncUserJWTGroups(ctx, userAuth); err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
if _, err := m.getUserFromUserAuth(ctx, userAuth); err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// propagates ctx change to upstream middleware
|
||||
|
||||
@@ -211,7 +211,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -271,7 +270,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -324,7 +322,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -368,7 +365,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -413,7 +409,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -478,7 +473,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -538,7 +532,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -594,7 +587,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -695,7 +687,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@@ -136,8 +137,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -266,8 +271,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -1216,7 +1216,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("NetworkResources").
|
||||
Preload("Onboarding").
|
||||
Preload("Services.Targets").
|
||||
Preload("Domains").
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
@@ -1303,7 +1302,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 16)
|
||||
errChan := make(chan error, 12)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1404,17 +1403,6 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
account.Services = services
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
domains, err := s.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Domains = domains
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -23,63 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
|
||||
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
|
||||
// private service's DNS zone; without the preload the relation is empty and the
|
||||
// service is silently skipped, so a custom domain never resolves on clients.
|
||||
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assertGetAccountLoadsCustomDomains(t, store)
|
||||
}
|
||||
|
||||
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
|
||||
// paths: it persists two custom domains and asserts the relation comes back
|
||||
// populated, which SynthesizePrivateServiceZones relies on.
|
||||
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "acct-custom-domains"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
|
||||
require.NoError(t, err, "creating the first custom domain must succeed")
|
||||
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
|
||||
require.NoError(t, err, "creating the second custom domain must succeed")
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
|
||||
|
||||
byDomain := map[string]string{}
|
||||
for _, d := range account.Domains {
|
||||
require.NotNil(t, d)
|
||||
byDomain[d.Domain] = d.TargetCluster
|
||||
}
|
||||
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
|
||||
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
|
||||
}
|
||||
|
||||
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
|
||||
// all fields and nested objects from the database, including deeply nested structures.
|
||||
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
|
||||
|
||||
@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByApex := map[string]*nbdns.CustomZone{}
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
@@ -290,24 +290,19 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
continue
|
||||
}
|
||||
|
||||
serviceDomainZone := a.privateServiceDomainZone(svc)
|
||||
if serviceDomainZone == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByApex[serviceDomainZone]
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the zone apex.
|
||||
// other name under the cluster apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(serviceDomainZone),
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByApex[serviceDomainZone] = zone
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
@@ -345,8 +340,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
|
||||
for _, zone := range zonesByApex {
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -362,33 +357,6 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
|
||||
return out
|
||||
}
|
||||
|
||||
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
|
||||
// looking at the proxy cluster domain then the custom domains.
|
||||
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
|
||||
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
|
||||
return svc.ProxyCluster
|
||||
}
|
||||
|
||||
// Longest matching custom domain wins
|
||||
zoneName := ""
|
||||
for _, d := range a.Domains {
|
||||
if d == nil || d.TargetCluster != svc.ProxyCluster {
|
||||
continue
|
||||
}
|
||||
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
|
||||
zoneName = d.Domain
|
||||
}
|
||||
}
|
||||
return zoneName
|
||||
}
|
||||
|
||||
func domainFromSuffix(domain, suffix string) bool {
|
||||
if suffix == "" {
|
||||
return false
|
||||
}
|
||||
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
@@ -235,113 +234,6 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
|
||||
// is the cluster serving it, and account.Domains holds the registered
|
||||
// custom domain. The synth zone apex must be the registered domain,
|
||||
// not the cluster, or the client's match-only zone never intercepts
|
||||
// the query.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
|
||||
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "custom",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
|
||||
|
||||
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
|
||||
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
|
||||
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
|
||||
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
account.Services[0].Domain = "a.example.com"
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "bapp",
|
||||
Domain: "b.example.com",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
|
||||
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
|
||||
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// Service domain is outside the cluster and no account.Domains entry
|
||||
// covers it: there is no apex that would intercept the query, so the
|
||||
// service must be skipped rather than emit an unmatchable record.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
// The registered custom domain matches the service FQDN by suffix but
|
||||
// targets a different cluster than the service's ProxyCluster. It must
|
||||
// be ignored, leaving no apex to intercept the query — otherwise the
|
||||
// zone would point at this cluster's proxy peers under a domain owned
|
||||
// by a different cluster.
|
||||
account.Services[0].Domain = "app.example.com"
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
@@ -362,72 +254,3 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
|
||||
}
|
||||
|
||||
privateService := func(id, domain string) *service.Service {
|
||||
return &service.Service{
|
||||
ID: id,
|
||||
AccountID: "acct-1",
|
||||
Name: id,
|
||||
Domain: domain,
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
}
|
||||
}
|
||||
publicService := func(id, domain string) *service.Service {
|
||||
s := privateService(id, domain)
|
||||
s.Private = false
|
||||
return s
|
||||
}
|
||||
|
||||
account.Services = []*service.Service{
|
||||
// 3 private services under the cluster suffix.
|
||||
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
|
||||
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
|
||||
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
|
||||
// 4 private services under the custom domain suffix.
|
||||
privateService("custom-1", "custom1.example.com"),
|
||||
privateService("custom-2", "custom2.example.com"),
|
||||
privateService("custom-3", "custom3.example.com"),
|
||||
privateService("custom-4", "custom4.example.com"),
|
||||
// 2 public services, one per suffix, must not surface.
|
||||
publicService("public-cluster", "public.eu.proxy.netbird.io"),
|
||||
publicService("public-custom", "public.example.com"),
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
|
||||
|
||||
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
|
||||
clusterNames := recordNames(cluster)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
|
||||
clusterNames,
|
||||
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
|
||||
|
||||
custom, ok := findCustomZone(zones, "example.com")
|
||||
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
|
||||
customNames := recordNames(custom)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
|
||||
customNames,
|
||||
"only the 4 private custom services surface in the custom zone (public one excluded)")
|
||||
}
|
||||
|
||||
// recordNames returns the record names of a zone for order-independent assertions.
|
||||
func recordNames(zone nbdns.CustomZone) []string {
|
||||
names := make([]string, 0, len(zone.Records))
|
||||
for _, r := range zone.Records {
|
||||
names = append(names, r.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user