mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 08:09:55 +00:00
feat(private-service): expose NetBird-only services over tunnel peers
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes. Wire contract - ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed. - ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy. - ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups. - ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip. - ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards. - PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard. Data model - Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns. - DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure. - Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0. Proxy - /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go. - Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack. - proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction. - Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC. - proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level). Identity forwarding - ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request. - CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing). - AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships. OpenAPI/dashboard surface - ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
This commit is contained in:
@@ -5,6 +5,7 @@ package peers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -35,6 +36,14 @@ type Manager interface {
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||
// Returns nil with an error when no match exists. No permission check;
|
||||
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||
// peer's group memberships.
|
||||
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -99,6 +108,24 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups returns the peer plus its group memberships.
|
||||
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return p, nil, err
|
||||
}
|
||||
return p, groups, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ package peers
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
// DeletePeers mocks base method.
|
||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP mocks base method.
|
||||
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerID mocks base method.
|
||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups mocks base method.
|
||||
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].([]*types.Group)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ type AccessLogEntry struct {
|
||||
BytesDownload int64 `gorm:"index"`
|
||||
Protocol AccessLogProtocol `gorm:"index"`
|
||||
Metadata map[string]string `gorm:"serializer:json"`
|
||||
// UserGroups captures the group IDs the user (UserId) belonged to at
|
||||
// the time the entry was written, so the dashboard can render group
|
||||
// context without reverse-resolving stale memberships.
|
||||
UserGroups []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||
@@ -125,6 +129,12 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||
metadata = &a.Metadata
|
||||
}
|
||||
|
||||
var userGroups *[]string
|
||||
if len(a.UserGroups) > 0 {
|
||||
groups := append([]string(nil), a.UserGroups...)
|
||||
userGroups = &groups
|
||||
}
|
||||
|
||||
return &api.ProxyAccessLog{
|
||||
Id: a.ID,
|
||||
ServiceId: a.ServiceID,
|
||||
@@ -145,5 +155,6 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
||||
BytesDownload: a.BytesDownload,
|
||||
Protocol: protocol,
|
||||
Metadata: metadata,
|
||||
UserGroups: userGroups,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestSaveAccessLog_EnrichesUserGroups(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
user := &types.User{Id: "u1", AutoGroups: []string{"g1", "g2"}}
|
||||
mockStore.EXPECT().
|
||||
GetUserByUserID(gomock.Any(), store.LockingStrengthNone, "u1").
|
||||
Return(user, nil)
|
||||
|
||||
var captured *accesslogs.AccessLogEntry
|
||||
mockStore.EXPECT().
|
||||
CreateAccessLog(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, e *accesslogs.AccessLogEntry) error {
|
||||
captured = e
|
||||
return nil
|
||||
})
|
||||
|
||||
m := &managerImpl{store: mockStore}
|
||||
entry := &accesslogs.AccessLogEntry{AccountID: "acc-1", UserId: "u1"}
|
||||
require.NoError(t, m.SaveAccessLog(context.Background(), entry))
|
||||
|
||||
require.NotNil(t, captured, "CreateAccessLog must receive the entry")
|
||||
assert.Equal(t, []string{"g1", "g2"}, captured.UserGroups, "UserGroups should be hydrated from the user record")
|
||||
}
|
||||
@@ -47,6 +47,8 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
|
||||
}
|
||||
}
|
||||
|
||||
m.enrichUserGroups(ctx, logEntry)
|
||||
|
||||
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"service_id": logEntry.ServiceID,
|
||||
@@ -161,6 +163,25 @@ func (m *managerImpl) StopPeriodicCleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// enrichUserGroups attaches the user's auto-group memberships to the entry.
|
||||
// Best-effort: errors are logged at debug and never block the save.
|
||||
func (m *managerImpl) enrichUserGroups(ctx context.Context, logEntry *accesslogs.AccessLogEntry) {
|
||||
if logEntry.UserId == "" {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, logEntry.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("access log user-group enrichment skipped for user %s: %v", logEntry.UserId, err)
|
||||
return
|
||||
}
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logEntry.UserGroups = append([]string(nil), user.AutoGroups...)
|
||||
}
|
||||
|
||||
// resolveUserFilters converts user email/name filters to user ID filter
|
||||
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
|
||||
if filter.UserEmail == nil && filter.UserName == nil {
|
||||
|
||||
@@ -23,6 +23,8 @@ type Domain struct {
|
||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||
// Not persisted.
|
||||
SupportsCrowdSec *bool `gorm:"-"`
|
||||
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||
SupportsPrivate *bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// EventMeta returns activity event metadata for a domain
|
||||
|
||||
@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||
RequireSubdomain: d.RequireSubdomain,
|
||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||
SupportsPrivate: d.SupportsPrivate,
|
||||
}
|
||||
if d.TargetCluster != "" {
|
||||
resp.TargetCluster = &d.TargetCluster
|
||||
|
||||
@@ -35,6 +35,7 @@ type proxyManager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ type Manager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
|
||||
@@ -21,6 +21,7 @@ type store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
|
||||
@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate mocks base method.
|
||||
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -20,6 +20,11 @@ type Capabilities struct {
|
||||
RequireSubdomain *bool
|
||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||
SupportsCrowdsec *bool
|
||||
// Private indicates whether this proxy is embedded in a netbird client
|
||||
// and serves exclusively over the WireGuard tunnel (i.e. `netbird proxy`
|
||||
// rather than the standalone netbird-proxy binary). Surfaces upstream
|
||||
// so dashboards can distinguish per-peer / private clusters.
|
||||
Private *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
@@ -67,10 +72,10 @@ type Cluster struct {
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
// Private is true when at least one connected proxy reported the embedded-`netbird proxy` capability.
|
||||
Private *bool
|
||||
}
|
||||
|
||||
@@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
Private: c.Private,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +208,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
case service.TargetTypeCluster:
|
||||
// Cluster targets carry the upstream address on target_id; the
|
||||
// proxy resolves the destination at request time.
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
@@ -779,6 +782,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case service.TargetTypeCluster:
|
||||
// Cluster targets are addressed by target_id; no peer/resource lookup.
|
||||
default:
|
||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||
}
|
||||
@@ -962,12 +967,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1344,3 +1344,40 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
// No peer or resource lookups must be issued for cluster targets.
|
||||
targets := []*rpservice.Target{
|
||||
{TargetId: "eu.proxy.netbird.io", TargetType: rpservice.TargetTypeCluster},
|
||||
}
|
||||
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||
}
|
||||
|
||||
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
mgr := &Manager{store: mockStore}
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-1",
|
||||
AccountID: accountID,
|
||||
Targets: []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||
}
|
||||
|
||||
@@ -45,10 +45,11 @@ const (
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypeCluster TargetType = "cluster"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
@@ -60,14 +61,27 @@ type TargetOptions struct {
|
||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||
// the target via the proxy host's network stack. Useful for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
// Host carries the upstream address. For TargetTypeSubnet it is the only
|
||||
// source — operator-supplied. For TargetTypePeer / TargetTypeHost /
|
||||
// TargetTypeDomain it is overwritten by replaceHostByLookup with the
|
||||
// resolved peer IP / resource address, *unless* Options.DirectUpstream
|
||||
// is true and the operator supplied a non-empty value — then the
|
||||
// operator value is preserved so they can dial the upstream via the
|
||||
// host's network stack at an address other than the WG tunnel IP
|
||||
// (e.g. a LAN IP, localhost sidecar, or DNS name).
|
||||
Host string `json:"host"`
|
||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
@@ -200,6 +214,10 @@ type Service struct {
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
PortAutoAssigned bool
|
||||
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||
Private bool
|
||||
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
@@ -299,6 +317,12 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Private: &s.Private,
|
||||
}
|
||||
|
||||
if len(s.AccessGroups) > 0 {
|
||||
groups := append([]string(nil), s.AccessGroups...)
|
||||
resp.AccessGroups = &groups
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -308,6 +332,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := s.buildPathMappings()
|
||||
|
||||
@@ -349,6 +374,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
Private: s.Private,
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
@@ -455,7 +481,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
@@ -477,17 +504,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
if opts.DirectUpstream {
|
||||
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
@@ -537,6 +569,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
if o.DirectUpstream != nil {
|
||||
opts.DirectUpstream = *o.DirectUpstream
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -551,6 +586,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
if req.ListenPort != nil {
|
||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||
}
|
||||
if req.Private != nil {
|
||||
s.Private = *req.Private
|
||||
}
|
||||
if req.AccessGroups != nil {
|
||||
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||
} else {
|
||||
s.AccessGroups = nil
|
||||
}
|
||||
|
||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||
if err != nil {
|
||||
@@ -740,6 +783,9 @@ func (s *Service) Validate() error {
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.validatePrivateRequirements(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
@@ -753,6 +799,23 @@ func (s *Service) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||
func (s *Service) validatePrivateRequirements() error {
|
||||
if !s.Private {
|
||||
return nil
|
||||
}
|
||||
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||
}
|
||||
if len(s.AccessGroups) == 0 {
|
||||
return errors.New("private services require at least one access group")
|
||||
}
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
@@ -799,11 +862,20 @@ func (s *Service) validateHTTPTargets() error {
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// host field will be ignored
|
||||
// Host is normally overwritten by replaceHostByLookup with the
|
||||
// resolved peer IP / resource address; operator-supplied values
|
||||
// are honored only when DirectUpstream is set. Validate the
|
||||
// override here so misconfigured hosts fail fast at API time.
|
||||
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
// target_id carries the cluster address; the proxy resolves
|
||||
// the upstream at request time. Host/port may be empty.
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
@@ -821,25 +893,55 @@ func (s *Service) validateHTTPTargets() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||
// allowed — the lookup fills in the default peer IP / resource address.
|
||||
// Without DirectUpstream the Host value is silently overwritten by
|
||||
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||
// Host with DirectUpstream must look like a hostname or IP and must
|
||||
// not carry a port (port lives on Target.Port).
|
||||
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.ContainsAny(host, " \t/") {
|
||||
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// OK
|
||||
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return errors.New("target host is required for subnet targets")
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
// target_id carries the cluster address; the proxy resolves
|
||||
// the upstream at request time.
|
||||
default:
|
||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||
}
|
||||
@@ -1174,6 +1276,11 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
var accessGroups []string
|
||||
if len(s.AccessGroups) > 0 {
|
||||
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
@@ -1195,6 +1302,8 @@ func (s *Service) Copy() *Service {
|
||||
Mode: s.Mode,
|
||||
ListenPort: s.ListenPort,
|
||||
PortAutoAssigned: s.PortAutoAssigned,
|
||||
Private: s.Private,
|
||||
AccessGroups: accessGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -1116,3 +1117,148 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id must validate without host or port")
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||
}
|
||||
|
||||
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Mode = ModeTCP
|
||||
rp.ListenPort = 9000
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||
}
|
||||
|
||||
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||
svc := validProxy()
|
||||
svc.Private = true
|
||||
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||
cp := svc.Copy()
|
||||
require.NotNil(t, cp)
|
||||
assert.True(t, cp.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||
|
||||
cp.Private = false
|
||||
assert.True(t, svc.Private)
|
||||
|
||||
cp.AccessGroups[0] = "grp-other"
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||
}
|
||||
|
||||
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||
enabled := true
|
||||
private := true
|
||||
accessGroups := []string{"grp-admins"}
|
||||
targets := []api.ServiceTarget{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||
Protocol: "http",
|
||||
Port: 80,
|
||||
Enabled: true,
|
||||
}}
|
||||
req := &api.ServiceRequest{
|
||||
Name: "svc-private",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
Enabled: enabled,
|
||||
Private: &private,
|
||||
AccessGroups: &accessGroups,
|
||||
Targets: &targets,
|
||||
}
|
||||
|
||||
svc := &Service{}
|
||||
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||
assert.True(t, svc.Private)
|
||||
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||
|
||||
resp := svc.ToAPIResponse()
|
||||
require.NotNil(t, resp.Private)
|
||||
assert.True(t, *resp.Private)
|
||||
require.NotNil(t, resp.AccessGroups)
|
||||
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||
}
|
||||
|
||||
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-sso"},
|
||||
}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Mode = ModeTCP
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,20 @@ type KeyPair struct {
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
Method auth.Method `json:"method"`
|
||||
// Email is the calling user's email address. Carried so the
|
||||
// proxy can stamp identity on upstream requests (e.g.
|
||||
// x-litellm-end-user-id) without an extra management
|
||||
// round-trip on every cookie-bearing request.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Groups carries the user's group IDs so the proxy can stamp them
|
||||
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||
// without an extra management round-trip.
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
// GroupNames carries the human-readable display names for the ids
|
||||
// in Groups, ordered identically (positional pairing). Slice may be
|
||||
// shorter than Groups for tokens minted before names were
|
||||
// resolvable; the consumer falls back to ids for missing positions.
|
||||
GroupNames []string `json:"group_names,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (*KeyPair, error) {
|
||||
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||
// SignToken mints a session JWT for the given user and domain. email,
|
||||
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||
// authorise and stamp identity for policy-aware middlewares without a
|
||||
// management round-trip on every cookie-bearing request. groupNames
|
||||
// pairs positionally with groups; pass nil when names couldn't be
|
||||
// resolved.
|
||||
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode private key: %w", err)
|
||||
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
Method: method,
|
||||
Method: method,
|
||||
Email: email,
|
||||
Groups: append([]string(nil), groups...),
|
||||
GroupNames: append([]string(nil), groupNames...),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -79,7 +78,10 @@ type ProxyServiceServer struct {
|
||||
// Manager for peers
|
||||
peersManager peers.Manager
|
||||
|
||||
// Manager for users
|
||||
// Manager for users — also resolves user.AutoGroups to *types.Group
|
||||
// records via GetUserWithGroups, used to stamp human-readable
|
||||
// group names into JWT claims and onto upstream identity-bearing
|
||||
// requests (llm_identity_inject's tags header).
|
||||
usersManager users.Manager
|
||||
|
||||
// Store for one-time authentication tokens
|
||||
@@ -137,12 +139,9 @@ type proxyConnection struct {
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
// syncStream is set when the proxy connected via SyncMappings.
|
||||
// When non-nil, the sender goroutine uses this instead of stream.
|
||||
syncStream proto.ProxyService_SyncMappingsServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
@@ -210,322 +209,146 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// proxyConnectParams holds the validated parameters extracted from either
|
||||
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||
type proxyConnectParams struct {
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = req.GetCapabilities()
|
||||
ctx := stream.Context()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
stream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||
}
|
||||
|
||||
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||
// management waits for an ack from the proxy before sending the next batch.
|
||||
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
init, err := recvSyncInit(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = init.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
syncStream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
go s.drainRecv(stream, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||
}
|
||||
|
||||
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||
firstMsg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||
}
|
||||
init := firstMsg.GetInit()
|
||||
if init == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||
}
|
||||
return init, nil
|
||||
}
|
||||
|
||||
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||
// address availability for account-scoped tokens.
|
||||
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||
if proxyID == "" {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
if !isProxyAddressValid(address) {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||
if err != nil {
|
||||
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||
}
|
||||
}
|
||||
|
||||
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||
}
|
||||
|
||||
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||
// provides a partially initialised connSeed with stream-specific fields set;
|
||||
// the remaining fields are filled in here.
|
||||
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
if proxyID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
|
||||
proxyAddress := req.GetAddress()
|
||||
if !isProxyAddressValid(proxyAddress) {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
var tokenID string
|
||||
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||
if token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connSeed.proxyID = params.proxyID
|
||||
connSeed.sessionID = sessionID
|
||||
connSeed.address = params.address
|
||||
connSeed.accountID = accountID
|
||||
connSeed.tokenID = tokenID
|
||||
connSeed.capabilities = params.capabilities
|
||||
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||
connSeed.ctx = connCtx
|
||||
connSeed.cancel = cancel
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := params.capabilities; c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||
}
|
||||
|
||||
return connSeed, proxyRecord, nil
|
||||
}
|
||||
|
||||
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": newSessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||
// after a snapshot send failure.
|
||||
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
Private: c.Private,
|
||||
}
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||
// The proxy sends an ack for every incremental update; we don't need them
|
||||
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||
// canceled) is treated as a clean disconnect rather than an error.
|
||||
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"session_id": conn.sessionID,
|
||||
"address": conn.address,
|
||||
"cluster_addr": conn.address,
|
||||
"account_id": conn.accountID,
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
defer s.disconnectProxy(conn)
|
||||
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if bidi && isStreamClosed(err) {
|
||||
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||
return nil
|
||||
}
|
||||
log.Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||
case <-conn.ctx.Done():
|
||||
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||
return conn.ctx.Err()
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// disconnectProxy removes the connection from cluster and store, unless it
|
||||
// has already been superseded by a newer connection.
|
||||
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||
// one batch, then waits for the proxy to ack before sending the next.
|
||||
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive ack: %w", err)
|
||||
}
|
||||
if msg.GetAck() == nil {
|
||||
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
@@ -562,9 +385,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -644,26 +464,12 @@ func isProxyAddressValid(addr string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isStreamClosed returns true for errors that indicate normal stream
|
||||
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||
func isStreamClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
return status.Code(err) == codes.Canceled
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy.
|
||||
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||
// sender handles sending messages to proxy
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -673,17 +479,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||
if conn.syncStream != nil {
|
||||
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: resp.Mapping,
|
||||
InitialSyncComplete: resp.InitialSyncComplete,
|
||||
})
|
||||
}
|
||||
return conn.stream.Send(resp)
|
||||
}
|
||||
|
||||
// SendAccessLog processes access log from proxy
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
@@ -754,6 +549,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||
if len(connUpdate.Mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
@@ -863,7 +663,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
continue
|
||||
}
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: missing capability for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
@@ -882,24 +682,41 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
// proxyAcceptsMapping returns whether the proxy can receive this mapping. Private requires SupportsPrivateService; custom-port L4 requires SupportsCustomPorts. Removes always pass.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.GetPrivate() {
|
||||
caps := conn.capabilities
|
||||
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
// Old proxies that never reported capabilities don't understand
|
||||
// custom port mappings.
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// filterMappingsForProxy drops mappings the proxy can't safely receive.
|
||||
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
if update == nil {
|
||||
return update
|
||||
}
|
||||
filtered := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, m := range update.Mapping {
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -945,6 +762,7 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
Mode: m.Mode,
|
||||
ListenPort: m.ListenPort,
|
||||
AccessRestrictions: m.AccessRestrictions,
|
||||
Private: m.Private,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -953,7 +771,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
service, err := s.lookupServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
@@ -961,7 +779,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||
// secrets and have no user-level group context, so groups stay nil. Email
|
||||
// is also empty — these schemes don't resolve a user record at sign time.
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1050,7 +871,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -1058,8 +879,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
userEmail,
|
||||
service.Domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1070,6 +894,29 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||
// into parallel id and name slices. Used to feed JWT claims and
|
||||
// response messages from a single GetUserWithGroups call. Position i
|
||||
// of the returned slices always pairs the same group: ids[i] is the
|
||||
// group id, names[i] is its display name. Records the manager skipped
|
||||
// (orphan ids in user.AutoGroups with no matching group row) are
|
||||
// already absent here, so the consumer can rely on positional pairing.
|
||||
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||
if len(groups) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ids = make([]string, 0, len(groups))
|
||||
names = make([]string, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, g.ID)
|
||||
names = append(names, g.Name)
|
||||
}
|
||||
return ids, names
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
@@ -1216,19 +1063,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
return nil, status.Errorf(codes.InvalidArgument, "redirect URL must use http or https scheme")
|
||||
}
|
||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
||||
}
|
||||
var found bool
|
||||
for _, service := range services {
|
||||
if service.Domain == redirectURL.Hostname() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, err := s.getAccountServiceByDomain(ctx, req.GetAccountId(), redirectURL.Hostname()); err != nil {
|
||||
log.WithContext(ctx).Debugf("OIDC redirect URL %q does not match any service domain", redirectURL.Hostname())
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "service not found in store")
|
||||
}
|
||||
@@ -1334,7 +1169,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
return verifier, redirectURL, nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||
// user. The user's group memberships are embedded in the token so policy-aware
|
||||
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
@@ -1345,11 +1182,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||
}
|
||||
|
||||
var (
|
||||
email string
|
||||
groupIDs []string
|
||||
groupNames []string
|
||||
)
|
||||
if s.usersManager != nil {
|
||||
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if uerr != nil {
|
||||
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||
} else if user != nil {
|
||||
email = user.Email
|
||||
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userID,
|
||||
email,
|
||||
domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
}
|
||||
@@ -1411,6 +1266,11 @@ func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, acco
|
||||
return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID)
|
||||
}
|
||||
|
||||
// lookupServiceByID resolves a service by ID via the persisted store.
|
||||
func (s *ProxyServiceServer) lookupServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) {
|
||||
return s.serviceManager.GetServiceByID(ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// ValidateSession validates a session token and checks if the user has access to the domain.
|
||||
func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) {
|
||||
domain := req.GetDomain()
|
||||
@@ -1453,7 +1313,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1466,7 +1326,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
user, err := s.usersManager.GetUser(ctx, userID)
|
||||
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1500,12 +1360,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"user_id": userID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateSession: access denied")
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1515,10 +1378,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"email": user.Email,
|
||||
}).Debug("ValidateSession: access granted")
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1526,6 +1392,138 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the peer's group membership against the service's access groups.
|
||||
// Peers without a user (machine agents, automation workloads) are
|
||||
// first-class callers, so authorisation runs off peer-group memberships
|
||||
// rather than the optional owning user's auto-groups.
|
||||
//
|
||||
// The minted session JWT carries peer.ID (Subject) and a display identity
|
||||
// (Email): user.Email when the peer is attached to a user, else peer.Name.
|
||||
// Downstream identity-stamping middlewares emit this display identity on
|
||||
// upstream requests, so per-human attribution lines up across a user's
|
||||
// devices and unattached peers still produce a stable, human-readable id.
|
||||
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
domain := req.GetDomain()
|
||||
tunnelIPStr := req.GetTunnelIp()
|
||||
|
||||
if domain == "" || tunnelIPStr == "" {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "missing domain or tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||
if tunnelIP == nil {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "invalid_tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateTunnelPeer: service not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "service_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||
if err != nil || peer == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"tunnel_ip": tunnelIPStr,
|
||||
}).Debug("ValidateTunnelPeer: peer not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"peer_id": peer.ID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the
|
||||
// human is the principal — multiple peers belonging to the same
|
||||
// user share one access-log identity AND one consumption-counter
|
||||
// dimension (caps are per-human, not per-device). Unlinked peers
|
||||
// (machine agents, automation workloads) are their own principals
|
||||
// keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// (LiteLLM, Portkey) tag spend with — user.Email when linked, the
|
||||
// peer's hostname when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"peer_id": peer.ID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateTunnelPeer: access denied")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"tunnel_ip": tunnelIPStr,
|
||||
"peer_id": peer.ID,
|
||||
"peer_name": peer.Name,
|
||||
"principal_id": principalID,
|
||||
"identity": displayIdentity,
|
||||
}).Debug("ValidateTunnelPeer: access granted")
|
||||
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
SessionToken: token,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
||||
return nil
|
||||
@@ -1550,4 +1548,36 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
return fmt.Errorf("user not in allowed groups")
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer; private with empty AccessGroups fails closed.
|
||||
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||
if service.Private {
|
||||
if len(service.AccessGroups) == 0 {
|
||||
return fmt.Errorf("private service has no access groups")
|
||||
}
|
||||
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||
}
|
||||
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
if len(service.Auth.BearerAuth.DistributionGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||
allowedSet := make(map[string]bool, len(allowedGroups))
|
||||
for _, groupID := range allowedGroups {
|
||||
allowedSet[groupID] = true
|
||||
}
|
||||
for _, groupID := range peerGroupIDs {
|
||||
if allowedSet[groupID] {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("peer not in allowed groups")
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no access groups")
|
||||
})
|
||||
|
||||
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||
})
|
||||
|
||||
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-allowed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,411 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
|
||||
// sent messages and returns pre-loaded ack responses from Recv.
|
||||
type syncRecordingStream struct {
|
||||
grpc.ServerStream
|
||||
|
||||
mu sync.Mutex
|
||||
sent []*proto.SyncMappingsResponse
|
||||
recvMsgs []*proto.SyncMappingsRequest
|
||||
recvIdx int
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sent = append(s.sent, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.recvIdx >= len(s.recvMsgs) {
|
||||
return nil, fmt.Errorf("no more recv messages")
|
||||
}
|
||||
msg := s.recvMsgs[s.recvIdx]
|
||||
s.recvIdx++
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *syncRecordingStream) SendMsg(any) error { return nil }
|
||||
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func ackMsg() *proto.SyncMappingsRequest {
|
||||
return &proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
|
||||
|
||||
assert.Len(t, stream.sent[0].Mapping, 3)
|
||||
assert.False(t, stream.sent[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[1].Mapping, 3)
|
||||
assert.False(t, stream.sent[1].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[2].Mapping, 1)
|
||||
assert.True(t, stream.sent[2].InitialSyncComplete)
|
||||
|
||||
// All 3 acks consumed — including the final batch.
|
||||
assert.Equal(t, 3, stream.recvIdx)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, totalServices)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.sent[0].Mapping)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// No acks available — Recv will return error.
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "receive ack")
|
||||
// First batch should have been sent before the error.
|
||||
require.Len(t, stream.sent, 1)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// Send an init message instead of an ack.
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{
|
||||
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected ack")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
|
||||
// Verify batches are sent strictly sequentially — batch N+1 is not sent
|
||||
// until the ack for batch N is received, including the final batch.
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6 // 3 batches, 3 acks
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []string
|
||||
|
||||
// Build a stream that logs send/recv events so we can verify ordering.
|
||||
ackCh := make(chan struct{}, 3)
|
||||
stream := &orderTrackingStream{
|
||||
mu: &mu,
|
||||
events: &events,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
// Feed acks asynchronously after a short delay to simulate real proxy.
|
||||
go func() {
|
||||
for range 3 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
ackCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
|
||||
require.Len(t, events, 6)
|
||||
assert.Equal(t, "send", events[0])
|
||||
assert.Equal(t, "recv", events[1])
|
||||
assert.Equal(t, "send", events[2])
|
||||
assert.Equal(t, "recv", events[3])
|
||||
assert.Equal(t, "send", events[4])
|
||||
assert.Equal(t, "recv", events[5])
|
||||
}
|
||||
|
||||
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
|
||||
// an ack is signaled via ackCh.
|
||||
type orderTrackingStream struct {
|
||||
grpc.ServerStream
|
||||
mu *sync.Mutex
|
||||
events *[]string
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "send")
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "recv")
|
||||
s.mu.Unlock()
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
|
||||
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *orderTrackingStream) SendMsg(any) error { return nil }
|
||||
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
const ttl = 100 * time.Millisecond
|
||||
const ackDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
// Build a stream that validates tokens immediately on Send, then
|
||||
// delays the ack to ensure the next batch's tokens are generated fresh.
|
||||
var validateErrs []error
|
||||
ackCh := make(chan struct{}, 2)
|
||||
stream := &tokenValidatingSyncStream{
|
||||
tokenStore: s.tokenStore,
|
||||
validateErrs: &validateErrs,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Delay first ack so that if tokens were all generated upfront they'd expire.
|
||||
time.Sleep(ackDelay)
|
||||
ackCh <- struct{}{}
|
||||
// Final batch ack — immediate.
|
||||
ackCh <- struct{}{}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid: per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
type tokenValidatingSyncStream struct {
|
||||
grpc.ServerStream
|
||||
tokenStore *OneTimeTokenStore
|
||||
validateErrs *[]error
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
for _, mapping := range m.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
|
||||
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
|
||||
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
|
||||
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
|
||||
},
|
||||
InitialSyncComplete: true,
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, 1)
|
||||
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-2", AccountId: "acct-2"},
|
||||
},
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.messages, 1)
|
||||
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
|
||||
assert.True(t, resp.Valid, "User should be allowed access")
|
||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||
assert.Empty(t, resp.DeniedReason)
|
||||
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||
|
||||
@@ -2178,7 +2178,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated,
|
||||
private, access_groups
|
||||
FROM services WHERE account_id = $1`
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||
@@ -2193,10 +2194,11 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var accessGroups []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
var mode, source, sourcePeer sql.NullString
|
||||
var terminated, portAutoAssigned sql.NullBool
|
||||
var terminated, portAutoAssigned, private sql.NullBool
|
||||
var listenPort sql.NullInt64
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
@@ -2219,6 +2221,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
&source,
|
||||
&sourcePeer,
|
||||
&terminated,
|
||||
&private,
|
||||
&accessGroups,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2230,6 +2234,16 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessGroups) > 0 {
|
||||
if err := json.Unmarshal(accessGroups, &s.AccessGroups); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal access_groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if private.Valid {
|
||||
s.Private = private.Bool
|
||||
}
|
||||
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
@@ -5826,6 +5840,7 @@ var validCapabilityColumns = map[string]struct{}{
|
||||
"supports_custom_ports": {},
|
||||
"require_subdomain": {},
|
||||
"supports_crowdsec": {},
|
||||
"private": {},
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
@@ -5840,6 +5855,12 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
|
||||
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate reports whether any active proxy in the cluster
|
||||
// has the private capability (nil = unreported).
|
||||
func (s *SqlStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return s.getClusterCapability(ctx, clusterAddr, "private")
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
||||
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
||||
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
||||
|
||||
46
management/server/store/sql_store_service_test.go
Normal file
46
management/server/store/sql_store_service_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
)
|
||||
|
||||
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
account := newAccountWithId(ctx, "account_private_svc", "testuser", "")
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-private",
|
||||
AccountID: account.Id,
|
||||
Name: "private-svc",
|
||||
Domain: "private.example",
|
||||
ProxyCluster: "cluster.example",
|
||||
Enabled: true,
|
||||
Mode: rpservice.ModeHTTP,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-admins", "grp-ops"},
|
||||
}
|
||||
require.NoError(t, store.CreateService(ctx, svc))
|
||||
|
||||
loaded, err := store.GetAccount(ctx, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, loaded.Services, 1)
|
||||
|
||||
got := loaded.Services[0]
|
||||
assert.True(t, got.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, got.AccessGroups)
|
||||
})
|
||||
}
|
||||
@@ -312,6 +312,7 @@ type Store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
|
||||
@@ -1461,6 +1461,20 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate indicates an expected call of GetClusterSupportsPrivate.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsPrivate", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetCustomDomain mocks base method.
|
||||
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -32,7 +32,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTTL = 300
|
||||
defaultTTL = 300
|
||||
// privateServiceDNSRecordTTL is short so proxy-peer changes propagate quickly to clients.
|
||||
privateServiceDNSRecordTTL = 5
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
|
||||
@@ -254,6 +256,117 @@ func getUniqueHostLabel(name string, peerLabels LookupMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SynthesizePrivateServiceZones returns in-memory CustomZones with A records pointing each enabled private service the peer can reach at the cluster's proxy-peer IPs. One zone per cluster (multiple services share); records gated by AccessGroups.
|
||||
func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZone {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok || peer == nil {
|
||||
return nil
|
||||
}
|
||||
if len(a.Services) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyPeersByCluster := a.GetProxyPeers()
|
||||
if len(proxyPeersByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
continue
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
continue
|
||||
}
|
||||
if !peerInDistributionGroups(peerGroups, svc.AccessGroups) {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyPeersByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
skippedDisconnected := 0
|
||||
for _, p := range proxyPeers {
|
||||
if p == nil || !p.IP.IsValid() {
|
||||
continue
|
||||
}
|
||||
// Only emit a record when the proxy peer is actually
|
||||
// connected. A disconnected proxy peer's tunnel IP won't
|
||||
// answer; pointing DNS at it would produce a black hole
|
||||
// for as long as the record is cached client-side.
|
||||
if p.Status == nil || !p.Status.Connected {
|
||||
skippedDisconnected++
|
||||
continue
|
||||
}
|
||||
zone.Records = append(zone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(svc.Domain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: privateServiceDNSRecordTTL,
|
||||
RData: p.IP.String(),
|
||||
})
|
||||
emitted++
|
||||
}
|
||||
// Disagreement with the firewall path is the typical
|
||||
// "domain doesn't reach client but firewall rules do"
|
||||
// symptom: the synth service is otherwise fine, only the
|
||||
// proxy peer's persisted Connected flag is wrong (most
|
||||
// likely the connection reaper marked it disconnected even
|
||||
// though the gRPC stream is alive).
|
||||
if emitted == 0 && skippedDisconnected > 0 {
|
||||
log.Debugf("private-zone synth: svc %s domain=%s cluster=%s emitted_zero proxy_peers=%d all_disconnected=%d (firewall would still fire)",
|
||||
svc.ID, svc.Domain, svc.ProxyCluster, len(proxyPeers), skippedDisconnected)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, *zone)
|
||||
}
|
||||
if len(out) == 0 && len(a.Services) > 0 {
|
||||
// Targeted diagnostic for the "firewall yes, DNS no" divergence —
|
||||
// fires only when services exist but synth returns zero zones,
|
||||
// so accounts without private services produce no noise.
|
||||
log.Debugf("private-zone synth: peer %s account %s returned 0 zones from %d candidate service(s)",
|
||||
peerID, a.Id, len(a.Services))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
for _, gid := range distributionGroups {
|
||||
if _, ok := peerGroups[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
@@ -1498,6 +1611,53 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *servi
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeers)
|
||||
}
|
||||
|
||||
a.injectPrivateServicePolicies(service, proxyPeers)
|
||||
}
|
||||
|
||||
// injectPrivateServicePolicies synthesises an in-memory ACL: AccessGroups → cluster proxy peers on TCP 80/443.
|
||||
func (a *Account) injectPrivateServicePolicies(svc *service.Service, proxyPeers []*nbpeer.Peer) {
|
||||
if !svc.Private {
|
||||
return
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
return
|
||||
}
|
||||
if len(proxyPeers) == 0 {
|
||||
return
|
||||
}
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
a.Policies = append(a.Policies, a.createPrivateServicePolicy(svc, proxyPeer))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createPrivateServicePolicy(svc *service.Service, proxyPeer *nbpeer.Peer) *Policy {
|
||||
policyID := fmt.Sprintf("private-access-%s-%s", svc.ID, proxyPeer.ID)
|
||||
sources := append([]string(nil), svc.AccessGroups...)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Private Access to %s", svc.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access groups to reach %s", svc.Name),
|
||||
Enabled: true,
|
||||
Sources: sources,
|
||||
DestinationResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 443, End: 443},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
|
||||
@@ -119,6 +119,7 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
components.AccountZones = append(components.AccountZones, a.SynthesizePrivateServiceZones(peerID)...)
|
||||
|
||||
for _, nsGroup := range a.NameServerGroups {
|
||||
if nsGroup.Enabled {
|
||||
|
||||
85
management/server/types/account_private_netmap_test.go
Normal file
85
management/server/types/account_private_netmap_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestPrivateService_NetworkMap_UserPeer_AndProxyPeer(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["user-peer"].Meta.WtVersion = "0.50.0"
|
||||
account.Peers["proxy-peer"].Meta.WtVersion = "0.50.0"
|
||||
|
||||
ctx := context.Background()
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
t.Run("user-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok)
|
||||
require.Len(t, zone.Records, 1)
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", zone.Records[0].Name)
|
||||
assert.Equal(t, int(dns.TypeA), zone.Records[0].Type)
|
||||
assert.Equal(t, "100.64.0.99", zone.Records[0].RData)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "proxy-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.99", FirewallRuleDirectionOUT)
|
||||
})
|
||||
|
||||
t.Run("proxy-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "proxy-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "user-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.10", FirewallRuleDirectionIN)
|
||||
})
|
||||
}
|
||||
|
||||
func netmapPeerIDs(peers []*nbpeer.Peer) []string {
|
||||
ids := make([]string, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func assertPrivateServiceFirewallRules(t *testing.T, rules []*FirewallRule, peerIP string, direction int) {
|
||||
t.Helper()
|
||||
wantPorts := map[uint16]bool{80: false, 443: false}
|
||||
for _, r := range rules {
|
||||
if r == nil || r.PeerIP != peerIP || r.Direction != direction {
|
||||
continue
|
||||
}
|
||||
if r.Protocol != string(PolicyRuleProtocolTCP) || r.Action != string(PolicyTrafficActionAccept) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case r.PortRange.Start == r.PortRange.End && r.PortRange.Start != 0:
|
||||
wantPorts[r.PortRange.Start] = true
|
||||
case r.Port == "80":
|
||||
wantPorts[80] = true
|
||||
case r.Port == "443":
|
||||
wantPorts[443] = true
|
||||
}
|
||||
}
|
||||
for port, found := range wantPorts {
|
||||
assert.Truef(t, found, "missing TCP accept rule on port %d for peer %s direction %d", port, peerIP, direction)
|
||||
}
|
||||
}
|
||||
256
management/server/types/account_private_zones_test.go
Normal file
256
management/server/types/account_private_zones_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func privateZoneTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Settings: &Settings{},
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerInGroup_GetsRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "one cluster should produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "eu.proxy.netbird.io.", zone.Domain, "zone apex must be the cluster FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must be match-only so unrelated sibling names fall through to the upstream resolver")
|
||||
require.Len(t, zone.Records, 1, "one private service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, int(dns.TypeA), rec.Type, "record type must be A")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
assert.Equal(t, privateServiceDNSRecordTTL, rec.TTL, "TTL must match the synth-records constant")
|
||||
assert.Equal(t, nbdns.DefaultClass, rec.Class, "record class must be the package default")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerNotInGroup_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Groups["grp-admins"].Peers = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "peer outside distribution_groups must not see private-service records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NotPrivate_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "non-private service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoAccessGroups_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "private service without bearer auth must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoProxyPeers_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "no embedded proxy peer in cluster means no record to emit")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisabledService_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Enabled = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disabled service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisconnectedProxyPeer_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer"].Status = &nbpeer.PeerStatus{Connected: false}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disconnected proxy peer must not produce a DNS record (would be a black hole)")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PartiallyDisconnectedProxyPeers_OnlyConnectedSurface(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: false},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1)
|
||||
require.Len(t, zones[0].Records, 1, "only the connected proxy peer must surface")
|
||||
assert.Equal(t, "100.64.0.99", zones[0].Records[0].RData)
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MultipleProxyPeers_RoundRobin(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "still one cluster yields one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two proxy peers must produce two A records on the same name")
|
||||
rdata := []string{zones[0].Records[0].RData, zones[0].Records[1].RData}
|
||||
assert.ElementsMatch(t, []string{"100.64.0.99", "100.64.0.100"}, rdata, "both proxy peer IPs must surface")
|
||||
}
|
||||
|
||||
// findCustomZone returns the CustomZone whose Domain equals the FQDN
|
||||
// of want, or a zero value when not found. Tests use it to assert
|
||||
// that the synth zone reaches dnsUpdate.CustomZones end-to-end.
|
||||
func findCustomZone(zones []nbdns.CustomZone, want string) (nbdns.CustomZone, bool) {
|
||||
wantFqdn := dns.Fqdn(want)
|
||||
for _, z := range zones {
|
||||
if z.Domain == wantFqdn {
|
||||
return z, true
|
||||
}
|
||||
}
|
||||
return nbdns.CustomZone{}, false
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone
|
||||
// covers the components-based builder path. The components builder
|
||||
// appends SynthesizePrivateServiceZones to AccountZones; the
|
||||
// CalculateNetworkMapFromComponents step then merges AccountZones
|
||||
// into dnsUpdate.CustomZones.
|
||||
func TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm, "network map must be produced for an in-account peer")
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "shipped CustomZones must include the synth zone for the cluster")
|
||||
require.Len(t, zone.Records, 1, "exactly one record per private service per connected proxy peer")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone
|
||||
// confirms the negative case the user encountered: a peer whose
|
||||
// groups don't overlap the policy's distribution_groups gets a
|
||||
// network map with no synth zone (and the wildcard / peer zones still
|
||||
// flow through). This is the test mirror of the runtime confusion
|
||||
// where the user looked at a non-distribution-group peer and assumed
|
||||
// the synth path was broken.
|
||||
func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["outsider"] = &nbpeer.Peer{
|
||||
ID: "outsider",
|
||||
AccountID: "acct-1",
|
||||
Key: "outsider-key",
|
||||
IP: netip.MustParseAddr("100.64.0.20"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
}
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
"outsider": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "outsider", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
_, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "anotherapp",
|
||||
Domain: "anotherapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services on the same cluster must collapse into one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two services yield two A records")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package types
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -82,9 +84,9 @@ func setupTestAccount() *Account {
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"groupAll": {
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
Issued: GroupIssuedAPI,
|
||||
},
|
||||
"group1": {
|
||||
@@ -1583,3 +1585,203 @@ func Test_filterPeerAppliedZones(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_ProxyPeerGetsInboundRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
userPeerIP := netip.MustParseAddr("100.64.0.10")
|
||||
proxyPeerIP := netip.MustParseAddr("100.64.0.99")
|
||||
|
||||
account := &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: userPeerIP,
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: proxyPeerIP,
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
var found *Policy
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && p.ID == "private-access-svc-1-proxy-peer" {
|
||||
found = p
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "expected synthesised private-access policy in account.Policies")
|
||||
require.Len(t, found.Rules, 1, "policy should have exactly one rule")
|
||||
rule := found.Rules[0]
|
||||
assert.Equal(t, []string{"grp-admins"}, rule.Sources, "sources should be group IDs verbatim")
|
||||
assert.Equal(t, "proxy-peer", rule.DestinationResource.ID, "destination resource should be the proxy peer ID")
|
||||
assert.Equal(t, ResourceTypePeer, rule.DestinationResource.Type, "destination resource type should be peer")
|
||||
|
||||
validatedPeersMap := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
proxyPeer := account.Peers["proxy-peer"]
|
||||
aclPeers, firewallRules, _, _ := account.GetPeerConnectionResources(ctx, proxyPeer, validatedPeersMap, nil)
|
||||
|
||||
var sawUserAsAclPeer bool
|
||||
for _, p := range aclPeers {
|
||||
if p.ID == "user-peer" {
|
||||
sawUserAsAclPeer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, sawUserAsAclPeer, "proxy peer should see the user peer as an ACL peer")
|
||||
|
||||
var inboundRules []*FirewallRule
|
||||
for _, r := range firewallRules {
|
||||
if r.Direction == FirewallRuleDirectionIN && r.PeerIP == userPeerIP.String() {
|
||||
inboundRules = append(inboundRules, r)
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, inboundRules, "proxy peer should have inbound firewall rules from the user peer")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NotPrivate_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "non-private service must not synthesise an access policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_EmptyAccessGroups_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "private service with no access groups must not synthesise a policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NoProxyPeers_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "policy must not synthesise when the cluster has no proxy peers")
|
||||
}
|
||||
|
||||
func privateServiceTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func hasPrivateAccessPolicy(account *Account, serviceID string) bool {
|
||||
prefix := "private-access-" + serviceID + "-"
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && len(p.ID) > len(prefix) && p.ID[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -10,6 +10,13 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetUser(ctx context.Context, userID string) (*types.User, error)
|
||||
// GetUserWithGroups returns the user and the *types.Group
|
||||
// records for the user's AutoGroups, in the same order as
|
||||
// AutoGroups. Group ids that don't resolve to a stored group
|
||||
// are skipped from the returned slice (the parallel id list is
|
||||
// derivable from the returned User). Wraps two store calls
|
||||
// today; can be optimised to a single JOIN later if needed.
|
||||
GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -29,6 +36,27 @@ func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(user.AutoGroups) == 0 {
|
||||
return user, nil, nil
|
||||
}
|
||||
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
|
||||
if err != nil {
|
||||
return user, nil, err
|
||||
}
|
||||
groups := make([]*types.Group, 0, len(user.AutoGroups))
|
||||
for _, id := range user.AutoGroups {
|
||||
if g, ok := groupsMap[id]; ok && g != nil {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
}
|
||||
return user, groups, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &managerMock{}
|
||||
}
|
||||
@@ -47,3 +75,11 @@ func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerMock) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user