mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-22 16:49:58 +00:00
Compare commits
15 Commits
task/align
...
feat/priva
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0fd5c59ce3 | ||
|
|
a423b788c0 | ||
|
|
3aa62e31a6 | ||
|
|
43c7b4dc0b | ||
|
|
f7dff43e34 | ||
|
|
717c2b493d | ||
|
|
627ee71fa8 | ||
|
|
b21a91a507 | ||
|
|
06cc488e90 | ||
|
|
dd90c0d180 | ||
|
|
3928cf93ce | ||
|
|
564302f688 | ||
|
|
43d32ff17b | ||
|
|
036e91cdea | ||
|
|
167ee08e14 |
@@ -405,6 +405,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
}
|
||||
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
return state.PubKey, state.FQDN, true
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
|
||||
@@ -305,6 +305,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Returns the
|
||||
// zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||
req.False(ok, "unknown IP must report ok=false")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
state, ok := status.PeerStateByIP("fd00::1")
|
||||
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
|
||||
@@ -5,6 +5,7 @@ package peers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -35,6 +36,14 @@ type Manager interface {
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
GetPeerID(ctx context.Context, peerKey string) (string, error)
|
||||
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
|
||||
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
|
||||
// Returns nil with an error when no match exists. No permission check;
|
||||
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
|
||||
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
|
||||
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
|
||||
// to. Used by the proxy's auth path to authorise a request by the calling
|
||||
// peer's group memberships.
|
||||
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP delegates to the store's indexed lookup.
|
||||
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups returns the peer plus its group memberships. Any store
|
||||
// error returns (nil, nil, err) so callers never receive a valid peer
|
||||
// alongside a non-nil error.
|
||||
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return p, groups, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ package peers
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
// DeletePeers mocks base method.
|
||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP mocks base method.
|
||||
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
|
||||
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
|
||||
}
|
||||
|
||||
// GetPeerID mocks base method.
|
||||
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
|
||||
}
|
||||
|
||||
// GetPeerWithGroups mocks base method.
|
||||
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].([]*types.Group)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
|
||||
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||
}
|
||||
|
||||
// CreateProxyPeer mocks base method.
|
||||
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ type Domain struct {
|
||||
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
|
||||
// Not persisted.
|
||||
SupportsCrowdSec *bool `gorm:"-"`
|
||||
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
|
||||
SupportsPrivate *bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// EventMeta returns activity event metadata for a domain
|
||||
|
||||
@@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
SupportsCustomPorts: d.SupportsCustomPorts,
|
||||
RequireSubdomain: d.RequireSubdomain,
|
||||
SupportsCrowdsec: d.SupportsCrowdSec,
|
||||
SupportsPrivate: d.SupportsPrivate,
|
||||
}
|
||||
if d.TargetCluster != "" {
|
||||
resp.TargetCluster = &d.TargetCluster
|
||||
|
||||
@@ -35,6 +35,7 @@ type proxyManager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
|
||||
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
|
||||
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
@@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ type Manager interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
|
||||
@@ -21,6 +21,7 @@ type store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
|
||||
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
@@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
@@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
|
||||
@@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate mocks base method.
|
||||
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -20,6 +20,9 @@ type Capabilities struct {
|
||||
RequireSubdomain *bool
|
||||
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
|
||||
SupportsCrowdsec *bool
|
||||
// Private indicates whether this proxy supports inbound access via Wireguard
|
||||
// tunnel and netbird-only authentication policies
|
||||
Private *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
@@ -67,10 +70,9 @@ type Cluster struct {
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
Private *bool
|
||||
}
|
||||
|
||||
@@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
Private: c.Private,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -136,6 +137,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
@@ -208,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
case service.TargetTypeCluster:
|
||||
// Cluster targets carry the upstream address on target_id; the
|
||||
// proxy resolves the destination at request time.
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
@@ -779,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case service.TargetTypeCluster:
|
||||
if err := validateClusterTarget(target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
|
||||
}
|
||||
@@ -786,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateClusterTarget(target *service.Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
@@ -962,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
// No peer or resource lookups must be issued for cluster targets.
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Options: rpservice.TargetOptions{DirectUpstream: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
|
||||
}
|
||||
|
||||
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
|
||||
// store-side check that cluster targets must opt into the host-stack dial
|
||||
// path. Without DirectUpstream the proxy would route this target through
|
||||
// the embedded NetBird client and fail on every request.
|
||||
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
targets := []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "backend.lan",
|
||||
},
|
||||
}
|
||||
err := validateTargetReferences(ctx, mockStore, accountID, targets)
|
||||
require.Error(t, err, "cluster target without direct_upstream must be rejected")
|
||||
assert.ErrorContains(t, err, "direct upstream disabled")
|
||||
}
|
||||
|
||||
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
accountID := "test-account"
|
||||
|
||||
mgr := &Manager{store: mockStore}
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-1",
|
||||
AccountID: accountID,
|
||||
Targets: []*rpservice.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: rpservice.TargetTypeCluster,
|
||||
Host: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
|
||||
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
|
||||
}
|
||||
|
||||
@@ -45,10 +45,11 @@ const (
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
TargetTypeCluster TargetType = "cluster"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
@@ -60,6 +61,11 @@ type TargetOptions struct {
|
||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
|
||||
// the target via the proxy host's network stack. Useful for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Default false.
|
||||
DirectUpstream bool `json:"direct_upstream,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -67,7 +73,7 @@ type Target struct {
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Host string `json:"host"`
|
||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
@@ -200,6 +206,10 @@ type Service struct {
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
PortAutoAssigned bool
|
||||
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
|
||||
Private bool
|
||||
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
|
||||
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
@@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Private: &s.Private,
|
||||
}
|
||||
|
||||
if len(s.AccessGroups) > 0 {
|
||||
groups := append([]string(nil), s.AccessGroups...)
|
||||
resp.AccessGroups = &groups
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToProtoMapping converts the service into the wire format the proxy consumes.
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := s.buildPathMappings()
|
||||
|
||||
@@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
Private: s.Private,
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
@@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
|
||||
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
@@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
if opts.DirectUpstream {
|
||||
apiOpts.DirectUpstream = &opts.DirectUpstream
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
|
||||
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
DirectUpstream: opts.DirectUpstream,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
@@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
if o.DirectUpstream != nil {
|
||||
opts.DirectUpstream = *o.DirectUpstream
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
if req.ListenPort != nil {
|
||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||
}
|
||||
if req.Private != nil {
|
||||
s.Private = *req.Private
|
||||
}
|
||||
if req.AccessGroups != nil {
|
||||
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
|
||||
} else {
|
||||
s.AccessGroups = nil
|
||||
}
|
||||
|
||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||
if err != nil {
|
||||
@@ -740,6 +775,9 @@ func (s *Service) Validate() error {
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.validatePrivateRequirements(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
@@ -753,6 +791,23 @@ func (s *Service) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
|
||||
func (s *Service) validatePrivateRequirements() error {
|
||||
if !s.Private {
|
||||
return nil
|
||||
}
|
||||
if s.Mode != "" && s.Mode != ModeHTTP {
|
||||
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
|
||||
}
|
||||
if len(s.AccessGroups) == 0 {
|
||||
return errors.New("private services require at least one access group")
|
||||
}
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
@@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error {
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// host field will be ignored
|
||||
// Host is normally overwritten by replaceHostByLookup with the
|
||||
// resolved peer IP / resource address; operator-supplied values
|
||||
// are honored only when DirectUpstream is set. Validate the
|
||||
// override here so misconfigured hosts fail fast at API time.
|
||||
if err := validateDirectUpstreamHost(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
if err := validateClusterTarget(i, target); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
@@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
|
||||
func validateClusterTarget(idx int, target *Target) error {
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return fmt.Errorf("target %d: has empty host", idx)
|
||||
}
|
||||
if !target.Options.DirectUpstream {
|
||||
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
|
||||
}
|
||||
return validateDirectUpstreamHost(idx, target)
|
||||
}
|
||||
|
||||
// validateDirectUpstreamHost validates the operator-supplied Host on a
|
||||
// peer/host/domain target when DirectUpstream is set. Empty Host is
|
||||
// allowed — the lookup fills in the default peer IP / resource address.
|
||||
// Without DirectUpstream the Host value is silently overwritten by
|
||||
// replaceHostByLookup, so we don't validate it (preserves the historical
|
||||
// behaviour where APIs accepted any value and dropped it). Non-empty
|
||||
// Host with DirectUpstream must look like a hostname or IP and must
|
||||
// not carry a port (port lives on Target.Port).
|
||||
func validateDirectUpstreamHost(idx int, target *Target) error {
|
||||
if !target.Options.DirectUpstream {
|
||||
return nil
|
||||
}
|
||||
host := strings.TrimSpace(target.Host)
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.ContainsAny(host, " \t/") {
|
||||
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(host); err == nil {
|
||||
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
if target.TargetType != TargetTypeCluster && target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// OK
|
||||
if err := validateDirectUpstreamHost(0, target); err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return errors.New("target host is required for subnet targets")
|
||||
}
|
||||
case TargetTypeCluster:
|
||||
// target_id carries the cluster address; the proxy resolves
|
||||
// the upstream at request time.
|
||||
default:
|
||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||
}
|
||||
@@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
var accessGroups []string
|
||||
if len(s.AccessGroups) > 0 {
|
||||
accessGroups = append([]string(nil), s.AccessGroups...)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
@@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service {
|
||||
Mode: s.Mode,
|
||||
ListenPort: s.ListenPort,
|
||||
PortAutoAssigned: s.PortAutoAssigned,
|
||||
Private: s.Private,
|
||||
AccessGroups: accessGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "exceeds maximum length")
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
|
||||
}
|
||||
|
||||
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
|
||||
// rule that operator-supplied Host is mandatory: cluster targets dial the
|
||||
// upstream via the host network stack (direct_upstream is implied), so an
|
||||
// empty Host leaves the proxy with nothing to dial.
|
||||
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
|
||||
}
|
||||
|
||||
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
|
||||
// half of the cluster-target rule: DirectUpstream must be true so the
|
||||
// stdlib transport branch in MultiTransport is taken. Without it the
|
||||
// embedded NetBird client would try to dial the cluster address through
|
||||
// the WG tunnel, which is the wrong network for a cluster upstream.
|
||||
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
|
||||
}
|
||||
|
||||
func TestValidate_L4ClusterTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Mode = ModeTCP
|
||||
rp.ListenPort = 9000
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
|
||||
}
|
||||
|
||||
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
|
||||
svc := validProxy()
|
||||
svc.Private = true
|
||||
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
|
||||
cp := svc.Copy()
|
||||
require.NotNil(t, cp)
|
||||
assert.True(t, cp.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
|
||||
|
||||
cp.Private = false
|
||||
assert.True(t, svc.Private)
|
||||
|
||||
cp.AccessGroups[0] = "grp-other"
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
|
||||
}
|
||||
|
||||
func TestService_APIRoundtrip_Private(t *testing.T) {
|
||||
enabled := true
|
||||
private := true
|
||||
accessGroups := []string{"grp-admins"}
|
||||
targets := []api.ServiceTarget{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: api.ServiceTargetTargetType("cluster"),
|
||||
Protocol: "http",
|
||||
Port: 80,
|
||||
Enabled: true,
|
||||
}}
|
||||
req := &api.ServiceRequest{
|
||||
Name: "svc-private",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
Enabled: enabled,
|
||||
Private: &private,
|
||||
AccessGroups: &accessGroups,
|
||||
Targets: &targets,
|
||||
}
|
||||
|
||||
svc := &Service{}
|
||||
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
|
||||
assert.True(t, svc.Private)
|
||||
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
|
||||
|
||||
resp := svc.ToAPIResponse()
|
||||
require.NotNil(t, resp.Private)
|
||||
assert.True(t, *resp.Private)
|
||||
require.NotNil(t, resp.AccessGroups)
|
||||
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
|
||||
}
|
||||
|
||||
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "access group")
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-sso"},
|
||||
}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "backend.lan",
|
||||
Options: TargetOptions{DirectUpstream: true},
|
||||
Enabled: true,
|
||||
}}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Private = true
|
||||
rp.AccessGroups = []string{"grp-admins"}
|
||||
rp.Mode = ModeTCP
|
||||
rp.Targets = []*Target{{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: TargetTypeCluster,
|
||||
Protocol: "tcp",
|
||||
Enabled: true,
|
||||
}}
|
||||
assert.ErrorContains(t, rp.Validate(), "HTTP")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,20 @@ type KeyPair struct {
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
Method auth.Method `json:"method"`
|
||||
// Email is the calling user's email address. Carried so the
|
||||
// proxy can stamp identity on upstream requests (e.g.
|
||||
// x-litellm-end-user-id) without an extra management
|
||||
// round-trip on every cookie-bearing request.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Groups carries the user's group IDs so the proxy can stamp them
|
||||
// onto upstream requests (X-NetBird-Groups) from the cookie path
|
||||
// without an extra management round-trip.
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
// GroupNames carries the human-readable display names for the ids
|
||||
// in Groups, ordered identically (positional pairing). Slice may be
|
||||
// shorter than Groups for tokens minted before names were
|
||||
// resolvable; the consumer falls back to ids for missing positions.
|
||||
GroupNames []string `json:"group_names,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (*KeyPair, error) {
|
||||
@@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||
// SignToken mints a session JWT for the given user and domain. email,
|
||||
// groups, and groupNames, when non-empty, are embedded so the proxy can
|
||||
// authorise and stamp identity for policy-aware middlewares without a
|
||||
// management round-trip on every cookie-bearing request. groupNames
|
||||
// pairs positionally with groups; pass nil when names couldn't be
|
||||
// resolved.
|
||||
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
|
||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode private key: %w", err)
|
||||
@@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
Method: method,
|
||||
Method: method,
|
||||
Email: email,
|
||||
Groups: append([]string(nil), groups...),
|
||||
GroupNames: append([]string(nil), groupNames...),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
|
||||
@@ -351,6 +351,7 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
Private: c.Private,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -754,6 +755,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
|
||||
connUpdate = filterMappingsForProxy(conn, connUpdate)
|
||||
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
@@ -882,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
|
||||
// Private mappings require SupportsPrivateService; custom-port L4 mappings
|
||||
// require SupportsCustomPorts. Remove operations always pass so proxies can
|
||||
// clean up.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.GetPrivate() {
|
||||
caps := conn.capabilities
|
||||
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
@@ -900,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// filterMappingsForProxy drops mappings the proxy cannot safely receive
|
||||
// (e.g. private mappings to a proxy without SupportsPrivateService).
|
||||
// Returns the input unchanged when no filtering is needed.
|
||||
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
if update == nil || len(update.Mapping) == 0 {
|
||||
return update
|
||||
}
|
||||
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, m := range update.Mapping {
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, m)
|
||||
}
|
||||
if len(kept) == len(update.Mapping) {
|
||||
return update
|
||||
}
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: kept,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -961,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
|
||||
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
|
||||
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
|
||||
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
|
||||
// secrets and have no user-level group context, so groups stay nil. Email
|
||||
// is also empty — these schemes don't resolve a user record at sign time.
|
||||
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1050,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -1058,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
token, err := sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userId,
|
||||
userEmail,
|
||||
service.Domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1070,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
|
||||
// into parallel id and name slices. ids[i] and names[i] always pair to
|
||||
// the same group. nil entries (orphan ids the manager couldn't resolve)
|
||||
// are skipped so the consumer can rely on positional pairing.
|
||||
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
|
||||
if len(groups) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ids = make([]string, 0, len(groups))
|
||||
names = make([]string, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, g.ID)
|
||||
names = append(names, g.Name)
|
||||
}
|
||||
return ids, names
|
||||
}
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
@@ -1334,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
return verifier, redirectURL, nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and
|
||||
// user. The user's group memberships are embedded in the token so policy-aware
|
||||
// middlewares on the proxy can authorise without an extra management round-trip.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
@@ -1345,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
|
||||
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||
}
|
||||
|
||||
var (
|
||||
email string
|
||||
groupIDs []string
|
||||
groupNames []string
|
||||
)
|
||||
if s.usersManager != nil {
|
||||
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if uerr != nil {
|
||||
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
|
||||
} else if user != nil {
|
||||
email = user.Email
|
||||
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionkey.SignToken(
|
||||
service.SessionPrivateKey,
|
||||
userID,
|
||||
email,
|
||||
domain,
|
||||
method,
|
||||
groupIDs,
|
||||
groupNames,
|
||||
proxyauth.DefaultSessionExpiry,
|
||||
)
|
||||
}
|
||||
@@ -1453,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1466,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
user, err := s.usersManager.GetUser(ctx, userID)
|
||||
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
@@ -1500,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"user_id": userID,
|
||||
"error": err.Error(),
|
||||
}).Debug("ValidateSession: access denied")
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
Valid: false,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1515,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
"email": user.Email,
|
||||
}).Debug("ValidateSession: access granted")
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
|
||||
return &proto.ValidateSessionResponse{
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
Valid: true,
|
||||
UserId: user.Id,
|
||||
UserEmail: user.Email,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1551,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T { return &v }
|
||||
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the peer's group membership against the service's access groups.
|
||||
// Peers without a user (machine agents, automation workloads) are first-class
|
||||
// callers; authorisation runs off peer-group memberships rather than the
|
||||
// optional owning user's auto-groups. On success a session JWT is minted so
|
||||
// the proxy can install a cookie and skip subsequent management round-trips.
|
||||
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
domain := req.GetDomain()
|
||||
tunnelIPStr := req.GetTunnelIp()
|
||||
|
||||
if domain == "" || tunnelIPStr == "" {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "missing domain or tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
tunnelIP := net.ParseIP(tunnelIPStr)
|
||||
if tunnelIP == nil {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "invalid_tunnel_ip",
|
||||
}, nil
|
||||
}
|
||||
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "service_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
|
||||
// validate and mint session cookies for their own account's domains.
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
|
||||
if err != nil || peer == nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
DeniedReason: "peer_not_found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
//nolint:nilerr
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: false,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
DeniedReason: "not_in_group",
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"domain": domain,
|
||||
"tunnel_ip": tunnelIPStr,
|
||||
"peer_id": peer.ID,
|
||||
"principal_id": principalID,
|
||||
}).Debug("ValidateTunnelPeer: access granted")
|
||||
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
UserId: principalID,
|
||||
UserEmail: displayIdentity,
|
||||
SessionToken: token,
|
||||
PeerGroupIds: groupIDs,
|
||||
PeerGroupNames: groupNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
// boundary and must not trust upstream state). Bearer-auth services authorise
|
||||
// against DistributionGroups when populated. Non-private non-bearer services
|
||||
// are open.
|
||||
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
|
||||
if service.Private {
|
||||
if len(service.AccessGroups) == 0 {
|
||||
return fmt.Errorf("private service has no access groups")
|
||||
}
|
||||
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
|
||||
}
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
|
||||
// else a non-nil error.
|
||||
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
|
||||
if len(allowedGroups) == 0 {
|
||||
return fmt.Errorf("no allowed groups configured")
|
||||
}
|
||||
allowed := make(map[string]struct{}, len(allowedGroups))
|
||||
for _, g := range allowedGroups {
|
||||
allowed[g] = struct{}{}
|
||||
}
|
||||
for _, g := range peerGroupIDs {
|
||||
if _, ok := allowed[g]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("peer not in allowed groups")
|
||||
}
|
||||
|
||||
@@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPeerGroupAccess(t *testing.T) {
|
||||
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: nil}
|
||||
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no access groups")
|
||||
})
|
||||
|
||||
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
|
||||
})
|
||||
|
||||
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
|
||||
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
|
||||
})
|
||||
|
||||
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
|
||||
svc := &service.Service{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"grp-allowed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
|
||||
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
|
||||
})
|
||||
|
||||
t.Run("non-private non-bearer is open", func(t *testing.T) {
|
||||
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) {
|
||||
assert.True(t, resp.Valid, "User should be allowed access")
|
||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||
assert.Empty(t, resp.DeniedReason)
|
||||
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
@@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
|
||||
}
|
||||
|
||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -53,6 +54,7 @@ type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() types.Engine
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error)
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
@@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
servicesAuthPassword int
|
||||
servicesAuthPin int
|
||||
servicesAuthOIDC int
|
||||
// Private-service signals — track adoption of NetBird-only mode
|
||||
// (services backed by an embedded proxy peer + access groups).
|
||||
servicesPrivate int
|
||||
servicesPrivateWithGroups int
|
||||
servicesPrivateAccessGroupsSum int
|
||||
servicesWithDirectUpstream int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
servicesAuthOIDC++
|
||||
}
|
||||
|
||||
if service.Private {
|
||||
servicesPrivate++
|
||||
if len(service.AccessGroups) > 0 {
|
||||
servicesPrivateWithGroups++
|
||||
}
|
||||
servicesPrivateAccessGroupsSum += len(service.AccessGroups)
|
||||
}
|
||||
|
||||
for _, target := range service.Targets {
|
||||
if target.Options.DirectUpstream {
|
||||
servicesWithDirectUpstream++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy / BYOP cluster signals come from the proxies table aggregated
|
||||
// across all accounts in a single store query; nil on FileStore.
|
||||
proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("collect proxy metrics: %v", err)
|
||||
}
|
||||
|
||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||
metricsProperties["uptime"] = uptime
|
||||
metricsProperties["accounts"] = accounts
|
||||
@@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||
metricsProperties["services_private"] = servicesPrivate
|
||||
metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups
|
||||
metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum
|
||||
metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream
|
||||
metricsProperties["proxy_clusters"] = proxyMetrics.Clusters
|
||||
metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP
|
||||
metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate
|
||||
metricsProperties["proxies"] = proxyMetrics.Proxies
|
||||
metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected
|
||||
metricsProperties["custom_domains"] = customDomains
|
||||
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -123,7 +124,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "peer"},
|
||||
{TargetType: "host"},
|
||||
{TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
},
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||
@@ -141,6 +142,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||
},
|
||||
{
|
||||
ID: "svc3-private",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-eng", "grp-ops"},
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}},
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -254,6 +265,18 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e
|
||||
return 3, 2, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics returns canned proxy/cluster counts so the
|
||||
// generateProperties test can assert the BYOP signals end-to-end.
|
||||
func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) {
|
||||
return store.ProxyMetrics{
|
||||
Clusters: 3,
|
||||
ClustersBYOP: 1,
|
||||
ClustersPrivate: 1,
|
||||
Proxies: 4,
|
||||
ProxiesConnected: 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
func TestGenerateProperties(t *testing.T) {
|
||||
ds := mockDatasource{}
|
||||
@@ -393,17 +416,17 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||
}
|
||||
|
||||
if properties["services"] != 2 {
|
||||
t.Errorf("expected 2 services, got %v", properties["services"])
|
||||
if properties["services"] != 3 {
|
||||
t.Errorf("expected 3 services, got %v", properties["services"])
|
||||
}
|
||||
if properties["services_enabled"] != 1 {
|
||||
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
||||
if properties["services_enabled"] != 2 {
|
||||
t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"])
|
||||
}
|
||||
if properties["services_targets"] != 3 {
|
||||
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
||||
if properties["services_targets"] != 4 {
|
||||
t.Errorf("expected 4 services_targets, got %v", properties["services_targets"])
|
||||
}
|
||||
if properties["services_status_active"] != 1 {
|
||||
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
||||
if properties["services_status_active"] != 2 {
|
||||
t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"])
|
||||
}
|
||||
if properties["services_status_pending"] != 1 {
|
||||
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||
@@ -420,6 +443,9 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_target_type_domain"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||
}
|
||||
if properties["services_target_type_cluster"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"])
|
||||
}
|
||||
if properties["services_auth_password"] != 1 {
|
||||
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||
}
|
||||
@@ -429,6 +455,33 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["services_auth_pin"] != 0 {
|
||||
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||
}
|
||||
if properties["services_private"] != 1 {
|
||||
t.Errorf("expected 1 services_private, got %v", properties["services_private"])
|
||||
}
|
||||
if properties["services_private_with_access_groups"] != 1 {
|
||||
t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"])
|
||||
}
|
||||
if properties["services_private_access_groups_sum"] != 2 {
|
||||
t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"])
|
||||
}
|
||||
if properties["services_with_direct_upstream"] != 2 {
|
||||
t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"])
|
||||
}
|
||||
if properties["proxy_clusters"] != int64(3) {
|
||||
t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"])
|
||||
}
|
||||
if properties["proxy_clusters_byop"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"])
|
||||
}
|
||||
if properties["proxy_clusters_private"] != int64(1) {
|
||||
t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"])
|
||||
}
|
||||
if properties["proxies"] != int64(4) {
|
||||
t.Errorf("expected 4 proxies, got %v", properties["proxies"])
|
||||
}
|
||||
if properties["proxies_connected"] != int64(2) {
|
||||
t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"])
|
||||
}
|
||||
if properties["custom_domains"] != int64(3) {
|
||||
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||
}
|
||||
|
||||
@@ -125,6 +125,18 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
}
|
||||
|
||||
// An embedded proxy peer flipping to connected is the trigger for
|
||||
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
|
||||
// tunnel IP. Without an account-wide netmap recompute, user peers keep
|
||||
// the stale synth (or no synth at all on first connect) until some
|
||||
// other change pokes the controller. Fire OnPeersUpdated so the
|
||||
// buffered recompute fans the new state out to every peer.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -160,6 +172,17 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
|
||||
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
|
||||
// offline, drive an account-wide netmap recompute so the synthesized
|
||||
// DNS records that pointed at it are pulled. Without this the records
|
||||
// linger client-side at TTL until something else triggers a refresh.
|
||||
if peer.ProxyMeta.Embedded {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -274,3 +274,9 @@ func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
|
||||
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics is a no-op for FileStore — proxy/cluster state isn't
|
||||
// persisted in the JSON file format.
|
||||
func (s *FileStore) GetProxyMetrics(_ context.Context) (ProxyMetrics, error) {
|
||||
return ProxyMetrics{}, nil
|
||||
}
|
||||
|
||||
@@ -1090,6 +1090,38 @@ func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, er
|
||||
return total, validated, nil
|
||||
}
|
||||
|
||||
// GetProxyMetrics aggregates per-cluster + per-proxy counts for the
|
||||
// self-hosted telemetry payload. Single round-trip via conditional
|
||||
// aggregations so a large proxies table doesn't fan out into multiple
|
||||
// queries.
|
||||
func (s *SqlStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
|
||||
var m ProxyMetrics
|
||||
activeCutoff := time.Now().Add(-proxyActiveThreshold)
|
||||
|
||||
// COUNT(DISTINCT ... CASE WHEN ...) is portable across sqlite/postgres
|
||||
// (MySQL too) and keeps the round-trip to one. proxy.StatusConnected
|
||||
// is the same string the cluster-capability queries use; the active
|
||||
// window matches the cluster-capability semantics (only proxies
|
||||
// heartbeating within ~2 * heartbeat interval count as connected).
|
||||
row := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select(
|
||||
"COUNT(DISTINCT cluster_address) AS clusters, "+
|
||||
"COUNT(DISTINCT CASE WHEN account_id IS NOT NULL THEN cluster_address END) AS clusters_byop, "+
|
||||
"COUNT(DISTINCT CASE WHEN private = ? THEN cluster_address END) AS clusters_private, "+
|
||||
"COUNT(*) AS proxies, "+
|
||||
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS proxies_connected",
|
||||
true,
|
||||
proxy.StatusConnected,
|
||||
activeCutoff,
|
||||
).
|
||||
Row()
|
||||
if err := row.Scan(&m.Clusters, &m.ClustersBYOP, &m.ClustersPrivate, &m.Proxies, &m.ProxiesConnected); err != nil {
|
||||
return ProxyMetrics{}, fmt.Errorf("scan proxy metrics: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
var accounts []types.Account
|
||||
result := s.db.Find(&accounts)
|
||||
@@ -2178,7 +2210,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated,
|
||||
private, access_groups
|
||||
FROM services WHERE account_id = $1`
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||
@@ -2193,10 +2226,11 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var accessGroups []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
var mode, source, sourcePeer sql.NullString
|
||||
var terminated, portAutoAssigned sql.NullBool
|
||||
var terminated, portAutoAssigned, private sql.NullBool
|
||||
var listenPort sql.NullInt64
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
@@ -2219,6 +2253,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
&source,
|
||||
&sourcePeer,
|
||||
&terminated,
|
||||
&private,
|
||||
&accessGroups,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2230,6 +2266,16 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessGroups) > 0 {
|
||||
if err := json.Unmarshal(accessGroups, &s.AccessGroups); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal access_groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if private.Valid {
|
||||
s.Private = private.Bool
|
||||
}
|
||||
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
@@ -5826,6 +5872,7 @@ var validCapabilityColumns = map[string]struct{}{
|
||||
"supports_custom_ports": {},
|
||||
"require_subdomain": {},
|
||||
"supports_crowdsec": {},
|
||||
"private": {},
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
@@ -5840,6 +5887,12 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
|
||||
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate reports whether any active proxy in the cluster
|
||||
// has the private capability (nil = unreported).
|
||||
func (s *SqlStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
return s.getClusterCapability(ctx, clusterAddr, "private")
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
||||
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
||||
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
||||
@@ -5908,7 +5961,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
AnyTrue bool
|
||||
}
|
||||
|
||||
err := s.db.
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
|
||||
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
|
||||
|
||||
46
management/server/store/sql_store_service_test.go
Normal file
46
management/server/store/sql_store_service_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
)
|
||||
|
||||
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
account := newAccountWithId(ctx, "account_private_svc", "testuser", "")
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
svc := &rpservice.Service{
|
||||
ID: "svc-private",
|
||||
AccountID: account.Id,
|
||||
Name: "private-svc",
|
||||
Domain: "private.example",
|
||||
ProxyCluster: "cluster.example",
|
||||
Enabled: true,
|
||||
Mode: rpservice.ModeHTTP,
|
||||
Private: true,
|
||||
AccessGroups: []string{"grp-admins", "grp-ops"},
|
||||
}
|
||||
require.NoError(t, store.CreateService(ctx, svc))
|
||||
|
||||
loaded, err := store.GetAccount(ctx, account.Id)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, loaded.Services, 1)
|
||||
|
||||
got := loaded.Services[0]
|
||||
assert.True(t, got.Private)
|
||||
assert.Equal(t, []string{"grp-admins", "grp-ops"}, got.AccessGroups)
|
||||
})
|
||||
}
|
||||
@@ -312,6 +312,7 @@ type Store interface {
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
@@ -320,9 +321,38 @@ type Store interface {
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
// GetProxyMetrics returns aggregated proxy / cluster counts for the
|
||||
// self-hosted metrics worker. Self-hosted only — file-based stores
|
||||
// return a zero-valued struct.
|
||||
GetProxyMetrics(ctx context.Context) (ProxyMetrics, error)
|
||||
|
||||
GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||
}
|
||||
|
||||
// ProxyMetrics aggregates self-hosted proxy + cluster usage signals
|
||||
// surfaced to the telemetry payload. Each field is best-effort: when a
|
||||
// store cannot answer (e.g. FileStore) all fields are zero.
|
||||
type ProxyMetrics struct {
|
||||
// Clusters counts distinct cluster_address values across the proxies
|
||||
// table — every cluster the management server has heard from, online or not.
|
||||
Clusters int64
|
||||
// ClustersBYOP counts distinct cluster_address values that are owned
|
||||
// by an account (account_id IS NOT NULL). These are bring-your-own-proxy
|
||||
// installations as opposed to NetBird-operated shared clusters.
|
||||
ClustersBYOP int64
|
||||
// ClustersPrivate counts distinct cluster_address values where at
|
||||
// least one proxy reported the private capability (embedded
|
||||
// `netbird proxy` running inside a client).
|
||||
ClustersPrivate int64
|
||||
// Proxies is the total number of proxy rows currently persisted.
|
||||
Proxies int64
|
||||
// ProxiesConnected is the subset of proxies whose status is
|
||||
// "connected" AND last_seen falls within the active heartbeat window
|
||||
// (~2 * heartbeat interval). Proxies the controller hasn't pruned
|
||||
// yet but that are visibly stale don't count.
|
||||
ProxiesConnected int64
|
||||
}
|
||||
|
||||
const (
|
||||
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
|
||||
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
|
||||
@@ -1461,6 +1461,20 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsPrivate", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsPrivate indicates an expected call of GetClusterSupportsPrivate.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsPrivate", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsPrivate), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetCustomDomain mocks base method.
|
||||
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2076,6 +2090,21 @@ func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetProxyMetrics mocks base method.
|
||||
func (m *MockStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyMetrics", ctx)
|
||||
ret0, _ := ret[0].(ProxyMetrics)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyMetrics indicates an expected call of GetProxyMetrics.
|
||||
func (mr *MockStoreMockRecorder) GetProxyMetrics(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyMetrics", reflect.TypeOf((*MockStore)(nil).GetProxyMetrics), ctx)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -32,7 +32,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTTL = 300
|
||||
defaultTTL = 300
|
||||
// privateServiceDNSRecordTTL is short so proxy-peer changes propagate quickly to clients.
|
||||
privateServiceDNSRecordTTL = 5
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
|
||||
@@ -254,6 +256,117 @@ func getUniqueHostLabel(name string, peerLabels LookupMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SynthesizePrivateServiceZones returns in-memory CustomZones with A records pointing each enabled private service the peer can reach at the cluster's proxy-peer IPs. One zone per cluster (multiple services share); records gated by AccessGroups.
|
||||
func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZone {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok || peer == nil {
|
||||
return nil
|
||||
}
|
||||
if len(a.Services) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyPeersByCluster := a.GetProxyPeers()
|
||||
if len(proxyPeersByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
zonesByCluster := map[string]*nbdns.CustomZone{}
|
||||
|
||||
for _, svc := range a.Services {
|
||||
if svc == nil || !svc.Enabled || !svc.Private {
|
||||
continue
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
continue
|
||||
}
|
||||
if !peerInDistributionGroups(peerGroups, svc.AccessGroups) {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyPeersByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
zone, exists := zonesByCluster[svc.ProxyCluster]
|
||||
if !exists {
|
||||
// NonAuthoritative makes this a match-only zone: queries for
|
||||
// names without an explicit record fall through to the
|
||||
// upstream resolver instead of returning NXDOMAIN. Without
|
||||
// it, adding a single private service would black-hole every
|
||||
// other name under the cluster apex.
|
||||
zone = &nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(svc.ProxyCluster),
|
||||
Records: []nbdns.SimpleRecord{},
|
||||
NonAuthoritative: true,
|
||||
}
|
||||
zonesByCluster[svc.ProxyCluster] = zone
|
||||
}
|
||||
|
||||
emitted := 0
|
||||
skippedDisconnected := 0
|
||||
for _, p := range proxyPeers {
|
||||
if p == nil || !p.IP.IsValid() {
|
||||
continue
|
||||
}
|
||||
// Only emit a record when the proxy peer is actually
|
||||
// connected. A disconnected proxy peer's tunnel IP won't
|
||||
// answer; pointing DNS at it would produce a black hole
|
||||
// for as long as the record is cached client-side.
|
||||
if p.Status == nil || !p.Status.Connected {
|
||||
skippedDisconnected++
|
||||
continue
|
||||
}
|
||||
zone.Records = append(zone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(svc.Domain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: privateServiceDNSRecordTTL,
|
||||
RData: p.IP.String(),
|
||||
})
|
||||
emitted++
|
||||
}
|
||||
// Disagreement with the firewall path is the typical
|
||||
// "domain doesn't reach client but firewall rules do"
|
||||
// symptom: the synth service is otherwise fine, only the
|
||||
// proxy peer's persisted Connected flag is wrong (most
|
||||
// likely the connection reaper marked it disconnected even
|
||||
// though the gRPC stream is alive).
|
||||
if emitted == 0 && skippedDisconnected > 0 {
|
||||
log.Debugf("private-zone synth: svc %s domain=%s cluster=%s emitted_zero proxy_peers=%d all_disconnected=%d (firewall would still fire)",
|
||||
svc.ID, svc.Domain, svc.ProxyCluster, len(proxyPeers), skippedDisconnected)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
|
||||
for _, zone := range zonesByCluster {
|
||||
if len(zone.Records) == 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, *zone)
|
||||
}
|
||||
if len(out) == 0 && len(a.Services) > 0 {
|
||||
// Targeted diagnostic for the "firewall yes, DNS no" divergence —
|
||||
// fires only when services exist but synth returns zero zones,
|
||||
// so accounts without private services produce no noise.
|
||||
log.Debugf("private-zone synth: peer %s account %s returned 0 zones from %d candidate service(s)",
|
||||
peerID, a.Id, len(a.Services))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// peerInDistributionGroups reports whether any of the peer's groups
|
||||
// matches the service's bearer-auth distribution_groups.
|
||||
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
|
||||
for _, gid := range distributionGroups {
|
||||
if _, ok := peerGroups[gid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
@@ -1498,6 +1611,53 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *servi
|
||||
a.injectTargetProxyPolicies(ctx, service, target, proxyPeers)
|
||||
}
|
||||
|
||||
a.injectPrivateServicePolicies(service, proxyPeers)
|
||||
}
|
||||
|
||||
// injectPrivateServicePolicies synthesises an in-memory ACL: AccessGroups → cluster proxy peers on TCP 80/443.
|
||||
func (a *Account) injectPrivateServicePolicies(svc *service.Service, proxyPeers []*nbpeer.Peer) {
|
||||
if !svc.Private {
|
||||
return
|
||||
}
|
||||
if len(svc.AccessGroups) == 0 {
|
||||
return
|
||||
}
|
||||
if len(proxyPeers) == 0 {
|
||||
return
|
||||
}
|
||||
for _, proxyPeer := range proxyPeers {
|
||||
a.Policies = append(a.Policies, a.createPrivateServicePolicy(svc, proxyPeer))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createPrivateServicePolicy(svc *service.Service, proxyPeer *nbpeer.Peer) *Policy {
|
||||
policyID := fmt.Sprintf("private-access-%s-%s", svc.ID, proxyPeer.ID)
|
||||
sources := append([]string(nil), svc.AccessGroups...)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
Name: fmt.Sprintf("Private Access to %s", svc.Name),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: policyID,
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Allow access groups to reach %s", svc.Name),
|
||||
Enabled: true,
|
||||
Sources: sources,
|
||||
DestinationResource: Resource{
|
||||
ID: proxyPeer.ID,
|
||||
Type: ResourceTypePeer,
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 443, End: 443},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
|
||||
@@ -119,6 +119,7 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
components.AccountZones = append(components.AccountZones, a.SynthesizePrivateServiceZones(peerID)...)
|
||||
|
||||
for _, nsGroup := range a.NameServerGroups {
|
||||
if nsGroup.Enabled {
|
||||
|
||||
85
management/server/types/account_private_netmap_test.go
Normal file
85
management/server/types/account_private_netmap_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestPrivateService_NetworkMap_UserPeer_AndProxyPeer(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["user-peer"].Meta.WtVersion = "0.50.0"
|
||||
account.Peers["proxy-peer"].Meta.WtVersion = "0.50.0"
|
||||
|
||||
ctx := context.Background()
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
t.Run("user-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok)
|
||||
require.Len(t, zone.Records, 1)
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", zone.Records[0].Name)
|
||||
assert.Equal(t, int(dns.TypeA), zone.Records[0].Type)
|
||||
assert.Equal(t, "100.64.0.99", zone.Records[0].RData)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "proxy-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.99", FirewallRuleDirectionOUT)
|
||||
})
|
||||
|
||||
t.Run("proxy-peer update", func(t *testing.T) {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "proxy-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
assert.Contains(t, netmapPeerIDs(nm.Peers), "user-peer")
|
||||
assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.10", FirewallRuleDirectionIN)
|
||||
})
|
||||
}
|
||||
|
||||
func netmapPeerIDs(peers []*nbpeer.Peer) []string {
|
||||
ids := make([]string, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func assertPrivateServiceFirewallRules(t *testing.T, rules []*FirewallRule, peerIP string, direction int) {
|
||||
t.Helper()
|
||||
wantPorts := map[uint16]bool{80: false, 443: false}
|
||||
for _, r := range rules {
|
||||
if r == nil || r.PeerIP != peerIP || r.Direction != direction {
|
||||
continue
|
||||
}
|
||||
if r.Protocol != string(PolicyRuleProtocolTCP) || r.Action != string(PolicyTrafficActionAccept) {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case r.PortRange.Start == r.PortRange.End && r.PortRange.Start != 0:
|
||||
wantPorts[r.PortRange.Start] = true
|
||||
case r.Port == "80":
|
||||
wantPorts[80] = true
|
||||
case r.Port == "443":
|
||||
wantPorts[443] = true
|
||||
}
|
||||
}
|
||||
for port, found := range wantPorts {
|
||||
assert.Truef(t, found, "missing TCP accept rule on port %d for peer %s direction %d", port, peerIP, direction)
|
||||
}
|
||||
}
|
||||
256
management/server/types/account_private_zones_test.go
Normal file
256
management/server/types/account_private_zones_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func privateZoneTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Settings: &Settings{},
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerInGroup_GetsRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "one cluster should produce one zone")
|
||||
zone := zones[0]
|
||||
assert.Equal(t, "eu.proxy.netbird.io.", zone.Domain, "zone apex must be the cluster FQDN")
|
||||
assert.True(t, zone.NonAuthoritative, "synth zone must be match-only so unrelated sibling names fall through to the upstream resolver")
|
||||
require.Len(t, zone.Records, 1, "one private service yields one A record")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, int(dns.TypeA), rec.Type, "record type must be A")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
assert.Equal(t, privateServiceDNSRecordTTL, rec.TTL, "TTL must match the synth-records constant")
|
||||
assert.Equal(t, nbdns.DefaultClass, rec.Class, "record class must be the package default")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PeerNotInGroup_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Groups["grp-admins"].Peers = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "peer outside distribution_groups must not see private-service records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NotPrivate_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "non-private service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoAccessGroups_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "private service without bearer auth must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_NoProxyPeers_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "no embedded proxy peer in cluster means no record to emit")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisabledService_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services[0].Enabled = false
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disabled service must not produce DNS records")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_DisconnectedProxyPeer_NoRecord(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer"].Status = &nbpeer.PeerStatus{Connected: false}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
assert.Empty(t, zones, "disconnected proxy peer must not produce a DNS record (would be a black hole)")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_PartiallyDisconnectedProxyPeers_OnlyConnectedSurface(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: false},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1)
|
||||
require.Len(t, zones[0].Records, 1, "only the connected proxy peer must surface")
|
||||
assert.Equal(t, "100.64.0.99", zones[0].Records[0].RData)
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_MultipleProxyPeers_RoundRobin(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["proxy-peer-2"] = &nbpeer.Peer{
|
||||
ID: "proxy-peer-2",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-2-key",
|
||||
IP: netip.MustParseAddr("100.64.0.100"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"},
|
||||
}
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "still one cluster yields one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two proxy peers must produce two A records on the same name")
|
||||
rdata := []string{zones[0].Records[0].RData, zones[0].Records[1].RData}
|
||||
assert.ElementsMatch(t, []string{"100.64.0.99", "100.64.0.100"}, rdata, "both proxy peer IPs must surface")
|
||||
}
|
||||
|
||||
// findCustomZone returns the CustomZone whose Domain equals the FQDN
|
||||
// of want, or a zero value when not found. Tests use it to assert
|
||||
// that the synth zone reaches dnsUpdate.CustomZones end-to-end.
|
||||
func findCustomZone(zones []nbdns.CustomZone, want string) (nbdns.CustomZone, bool) {
|
||||
wantFqdn := dns.Fqdn(want)
|
||||
for _, z := range zones {
|
||||
if z.Domain == wantFqdn {
|
||||
return z, true
|
||||
}
|
||||
}
|
||||
return nbdns.CustomZone{}, false
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone
|
||||
// covers the components-based builder path. The components builder
|
||||
// appends SynthesizePrivateServiceZones to AccountZones; the
|
||||
// CalculateNetworkMapFromComponents step then merges AccountZones
|
||||
// into dnsUpdate.CustomZones.
|
||||
func TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm, "network map must be produced for an in-account peer")
|
||||
|
||||
zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
require.True(t, ok, "shipped CustomZones must include the synth zone for the cluster")
|
||||
require.Len(t, zone.Records, 1, "exactly one record per private service per connected proxy peer")
|
||||
rec := zone.Records[0]
|
||||
assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN")
|
||||
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
|
||||
}
|
||||
|
||||
// TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone
|
||||
// confirms the negative case the user encountered: a peer whose
|
||||
// groups don't overlap the policy's distribution_groups gets a
|
||||
// network map with no synth zone (and the wildcard / peer zones still
|
||||
// flow through). This is the test mirror of the runtime confusion
|
||||
// where the user looked at a non-distribution-group peer and assumed
|
||||
// the synth path was broken.
|
||||
func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Peers["outsider"] = &nbpeer.Peer{
|
||||
ID: "outsider",
|
||||
AccountID: "acct-1",
|
||||
Key: "outsider-key",
|
||||
IP: netip.MustParseAddr("100.64.0.20"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
}
|
||||
ctx := context.Background()
|
||||
validated := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
"outsider": {},
|
||||
}
|
||||
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, "outsider", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
_, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io")
|
||||
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
|
||||
}
|
||||
|
||||
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
|
||||
account := privateZoneTestAccount(t)
|
||||
account.Services = append(account.Services, &service.Service{
|
||||
ID: "svc-2",
|
||||
AccountID: "acct-1",
|
||||
Name: "anotherapp",
|
||||
Domain: "anotherapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
})
|
||||
|
||||
zones := account.SynthesizePrivateServiceZones("user-peer")
|
||||
require.Len(t, zones, 1, "two services on the same cluster must collapse into one zone")
|
||||
require.Len(t, zones[0].Records, 2, "two services yield two A records")
|
||||
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
|
||||
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package types
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -82,9 +84,9 @@ func setupTestAccount() *Account {
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"groupAll": {
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
ID: "groupAll",
|
||||
Name: "All",
|
||||
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
|
||||
Issued: GroupIssuedAPI,
|
||||
},
|
||||
"group1": {
|
||||
@@ -1583,3 +1585,203 @@ func Test_filterPeerAppliedZones(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_ProxyPeerGetsInboundRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
userPeerIP := netip.MustParseAddr("100.64.0.10")
|
||||
proxyPeerIP := netip.MustParseAddr("100.64.0.99")
|
||||
|
||||
account := &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: userPeerIP,
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: proxyPeerIP,
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
var found *Policy
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && p.ID == "private-access-svc-1-proxy-peer" {
|
||||
found = p
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "expected synthesised private-access policy in account.Policies")
|
||||
require.Len(t, found.Rules, 1, "policy should have exactly one rule")
|
||||
rule := found.Rules[0]
|
||||
assert.Equal(t, []string{"grp-admins"}, rule.Sources, "sources should be group IDs verbatim")
|
||||
assert.Equal(t, "proxy-peer", rule.DestinationResource.ID, "destination resource should be the proxy peer ID")
|
||||
assert.Equal(t, ResourceTypePeer, rule.DestinationResource.Type, "destination resource type should be peer")
|
||||
|
||||
validatedPeersMap := map[string]struct{}{
|
||||
"user-peer": {},
|
||||
"proxy-peer": {},
|
||||
}
|
||||
|
||||
proxyPeer := account.Peers["proxy-peer"]
|
||||
aclPeers, firewallRules, _, _ := account.GetPeerConnectionResources(ctx, proxyPeer, validatedPeersMap, nil)
|
||||
|
||||
var sawUserAsAclPeer bool
|
||||
for _, p := range aclPeers {
|
||||
if p.ID == "user-peer" {
|
||||
sawUserAsAclPeer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, sawUserAsAclPeer, "proxy peer should see the user peer as an ACL peer")
|
||||
|
||||
var inboundRules []*FirewallRule
|
||||
for _, r := range firewallRules {
|
||||
if r.Direction == FirewallRuleDirectionIN && r.PeerIP == userPeerIP.String() {
|
||||
inboundRules = append(inboundRules, r)
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, inboundRules, "proxy peer should have inbound firewall rules from the user peer")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NotPrivate_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].Private = false
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "non-private service must not synthesise an access policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_EmptyAccessGroups_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
account.Services[0].AccessGroups = nil
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "private service with no access groups must not synthesise a policy")
|
||||
}
|
||||
|
||||
func TestInjectPrivateServicePolicies_NoProxyPeers_NoPolicy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
account := privateServiceTestAccount(t)
|
||||
delete(account.Peers, "proxy-peer")
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "policy must not synthesise when the cluster has no proxy peers")
|
||||
}
|
||||
|
||||
func privateServiceTestAccount(t *testing.T) *Account {
|
||||
t.Helper()
|
||||
return &Account{
|
||||
Id: "acct-1",
|
||||
Network: &Network{
|
||||
Identifier: "net-1",
|
||||
Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"user-peer": {
|
||||
ID: "user-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "user-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.10"),
|
||||
},
|
||||
"proxy-peer": {
|
||||
ID: "proxy-peer",
|
||||
AccountID: "acct-1",
|
||||
Key: "proxy-peer-key",
|
||||
IP: netip.MustParseAddr("100.64.0.99"),
|
||||
ProxyMeta: nbpeer.ProxyMeta{
|
||||
Embedded: true,
|
||||
Cluster: "eu.proxy.netbird.io",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"grp-admins": {
|
||||
ID: "grp-admins",
|
||||
Name: "admins",
|
||||
Peers: []string{"user-peer"},
|
||||
},
|
||||
},
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "svc-1",
|
||||
AccountID: "acct-1",
|
||||
Name: "myapp",
|
||||
Domain: "myapp.eu.proxy.netbird.io",
|
||||
ProxyCluster: "eu.proxy.netbird.io",
|
||||
Enabled: true,
|
||||
Private: true,
|
||||
Mode: service.ModeHTTP,
|
||||
AccessGroups: []string{"grp-admins"},
|
||||
Targets: []*service.Target{
|
||||
{
|
||||
TargetId: "eu.proxy.netbird.io",
|
||||
TargetType: service.TargetTypeCluster,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func hasPrivateAccessPolicy(account *Account, serviceID string) bool {
|
||||
prefix := "private-access-" + serviceID + "-"
|
||||
for _, p := range account.Policies {
|
||||
if p != nil && len(p.ID) > len(prefix) && p.ID[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetUser(ctx context.Context, userID string) (*types.User, error)
|
||||
GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -29,6 +30,31 @@ func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
|
||||
// GetUserWithGroups returns the user and the *types.Group records for the user's AutoGroups, in the same order as
|
||||
// AutoGroups. Group ids that don't resolve to a stored group are skipped from the returned slice (the parallel id list is
|
||||
// derivable from the returned User). Wraps two store calls today; can be optimised to a single JOIN later if needed.
|
||||
// Any store error returns (nil, nil, err) so callers never receive a valid user alongside a non-nil error.
|
||||
func (m *managerImpl) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(user.AutoGroups) == 0 {
|
||||
return user, nil, nil
|
||||
}
|
||||
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
groups := make([]*types.Group, 0, len(user.AutoGroups))
|
||||
for _, id := range user.AutoGroups {
|
||||
if g, ok := groupsMap[id]; ok && g != nil {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
}
|
||||
return user, groups, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &managerMock{}
|
||||
}
|
||||
@@ -47,3 +73,11 @@ func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User,
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerMock) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
|
||||
user, err := m.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
@@ -45,10 +45,14 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
|
||||
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
|
||||
// ValidateSessionJWT validates a session JWT and returns the user ID, the
|
||||
// user's email (when carried), the authentication method, any embedded
|
||||
// group memberships, and the parallel group display names. email,
|
||||
// groups, and groupNames may be empty for tokens minted before those
|
||||
// claims were introduced. groupNames pairs positionally with groups.
|
||||
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) {
|
||||
if publicKey == nil {
|
||||
return "", "", fmt.Errorf("no public key configured for domain")
|
||||
return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain")
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
|
||||
@@ -58,20 +62,46 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey)
|
||||
return publicKey, nil
|
||||
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parse token: %w", err)
|
||||
return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", fmt.Errorf("invalid token claims")
|
||||
return "", "", "", nil, nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
sub, _ := claims.GetSubject()
|
||||
if sub == "" {
|
||||
return "", "", fmt.Errorf("missing subject claim")
|
||||
return "", "", "", nil, nil, fmt.Errorf("missing subject claim")
|
||||
}
|
||||
|
||||
methodClaim, _ := claims["method"].(string)
|
||||
emailClaim, _ := claims["email"].(string)
|
||||
groups = extractGroupsClaim(claims["groups"])
|
||||
groupNames = extractGroupsClaim(claims["group_names"])
|
||||
|
||||
return sub, methodClaim, nil
|
||||
return sub, emailClaim, methodClaim, groups, groupNames, nil
|
||||
}
|
||||
|
||||
// extractGroupsClaim decodes the "groups" claim into a string slice. The JWT
|
||||
// library decodes JSON arrays as []interface{}, so we coerce element-wise
|
||||
// and skip non-string entries silently.
|
||||
func extractGroupsClaim(claim interface{}) []string {
|
||||
raw, ok := claim.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
groups := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
groups = append(groups, s)
|
||||
}
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ var (
|
||||
preSharedKey string
|
||||
supportsCustomPorts bool
|
||||
requireSubdomain bool
|
||||
private bool
|
||||
geoDataDir string
|
||||
crowdsecAPIURL string
|
||||
crowdsecAPIKey string
|
||||
@@ -105,6 +106,12 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
|
||||
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
|
||||
rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain")
|
||||
// --private is internal: set by the embedded `netbird proxy` subcommand
|
||||
// via NB_PROXY_PRIVATE so management can distinguish per-peer / private
|
||||
// clusters from centralised ones. Hidden so the standalone CLI doesn't
|
||||
// surface it as an operator-facing toggle.
|
||||
rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Mark this proxy as embedded/private (internal flag)")
|
||||
_ = rootCmd.Flags().MarkHidden("private")
|
||||
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
|
||||
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
|
||||
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
|
||||
@@ -161,7 +168,8 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid --trusted-proxies: %w", err)
|
||||
}
|
||||
|
||||
srv := proxy.Server{
|
||||
srv := proxy.New(proxy.Config{
|
||||
ListenAddr: addr,
|
||||
Logger: logger,
|
||||
Version: Version,
|
||||
ManagementAddress: mgmtAddr,
|
||||
@@ -178,7 +186,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
ACMEChallengeType: acmeChallengeType,
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
HealthAddress: healthAddr,
|
||||
HealthAddr: healthAddr,
|
||||
ForwardedProto: forwardedProto,
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
@@ -188,12 +196,13 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
PreSharedKey: preSharedKey,
|
||||
SupportsCustomPorts: supportsCustomPorts,
|
||||
RequireSubdomain: requireSubdomain,
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
}
|
||||
})
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
537
proxy/inbound.go
Normal file
537
proxy/inbound.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// httpInboundReadHeaderTimeout matches the host-listener read header timeout
|
||||
// so per-account http.Servers don't leak idle connections.
|
||||
const httpInboundReadHeaderTimeout = 30 * time.Second
|
||||
|
||||
// httpInboundIdleTimeout caps idle keep-alive on per-account inbound HTTP
|
||||
// servers; matches the host listener.
|
||||
const httpInboundIdleTimeout = 90 * time.Second
|
||||
|
||||
// inboundShutdownTimeout caps how long a per-account http.Server gets to
|
||||
// drain in-flight requests during teardown.
|
||||
const inboundShutdownTimeout = 5 * time.Second
|
||||
|
||||
// privateInboundPortHTTPS is the WG-side TLS port. Each account's
|
||||
// embedded netstack binds independently, so a fixed port is fine.
|
||||
const privateInboundPortHTTPS = 443
|
||||
|
||||
// privateInboundPortHTTP is the WG-side plain-HTTP port.
|
||||
const privateInboundPortHTTP = 80
|
||||
|
||||
// inboundManager wires per-account inbound listeners into the proxy
|
||||
// pipeline when --private-inbound is enabled. When disabled the manager
|
||||
// is nil and every method on *Server that touches it short-circuits.
|
||||
type inboundManager struct {
|
||||
logger *log.Logger
|
||||
handler http.Handler
|
||||
tlsConfig *tls.Config
|
||||
// muxLock guards entries and pendingRoutes.
|
||||
muxLock sync.Mutex
|
||||
entries map[types.AccountID]*inboundEntry
|
||||
pendingRoutes map[types.AccountID][]pendingInboundRoute
|
||||
}
|
||||
|
||||
// inboundEntry owns the listeners, router and HTTP servers for a single
|
||||
// account's embedded netstack.
|
||||
type inboundEntry struct {
|
||||
router *nbtcp.Router
|
||||
tlsListener net.Listener
|
||||
plainListener net.Listener
|
||||
httpsServer *http.Server
|
||||
httpServer *http.Server
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// pendingInboundRoute holds a route that arrived before the account's
|
||||
// listener finished starting.
|
||||
type pendingInboundRoute struct {
|
||||
host nbtcp.SNIHost
|
||||
route nbtcp.Route
|
||||
}
|
||||
|
||||
// newInboundManager constructs a manager bound to the proxy's HTTP
|
||||
// handler chain and TLS config.
|
||||
func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager {
|
||||
return &inboundManager{
|
||||
logger: logger,
|
||||
handler: handler,
|
||||
tlsConfig: tlsConfig,
|
||||
entries: make(map[types.AccountID]*inboundEntry),
|
||||
pendingRoutes: make(map[types.AccountID][]pendingInboundRoute),
|
||||
}
|
||||
}
|
||||
|
||||
// onClientReady is registered with NetBird.SetClientLifecycle so the
|
||||
// listener pair comes up exactly when the embedded client reports ready.
|
||||
// The returned value is opaque to the roundtrip package; it is handed
|
||||
// back verbatim to onClientStop on teardown.
|
||||
func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
entry, err := m.bringUp(ctx, accountID, client)
|
||||
if err != nil {
|
||||
m.logger.WithField("account_id", accountID).WithError(err).Warn("failed to start per-account inbound listener; continuing without inbound")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.flushPending(accountID, entry)
|
||||
|
||||
m.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"https": entry.tlsListener.Addr().String(),
|
||||
"http": entry.plainListener.Addr().String(),
|
||||
}).Info("per-account inbound listeners up")
|
||||
return entry
|
||||
}
|
||||
|
||||
// onClientStop tears down a per-account listener bundle. State is the
|
||||
// opaque value previously returned by onClientReady.
|
||||
func (m *inboundManager) onClientStop(accountID types.AccountID, state any) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
entry, ok := state.(*inboundEntry)
|
||||
if !ok || entry == nil {
|
||||
return
|
||||
}
|
||||
m.tearDown(accountID, entry)
|
||||
}
|
||||
|
||||
// bringUp opens both listeners on the account's netstack, builds the
|
||||
// router, and starts the parallel HTTP servers.
|
||||
func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) {
|
||||
tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen tls on netstack: %w", err)
|
||||
}
|
||||
plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP))
|
||||
if err != nil {
|
||||
_ = tlsListener.Close()
|
||||
return nil, fmt.Errorf("listen plain on netstack: %w", err)
|
||||
}
|
||||
|
||||
router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr()))
|
||||
|
||||
scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client))
|
||||
|
||||
httpsServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
TLSConfig: m.tlsConfig,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
|
||||
}
|
||||
httpServer := &http.Server{
|
||||
Handler: scopedHandler,
|
||||
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
|
||||
IdleTimeout: httpInboundIdleTimeout,
|
||||
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
entry := &inboundEntry{
|
||||
router: router,
|
||||
tlsListener: tlsListener,
|
||||
plainListener: plainListener,
|
||||
httpsServer: httpsServer,
|
||||
httpServer: httpServer,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := router.Serve(runCtx, tlsListener); err != nil {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
entry.wg.Add(1)
|
||||
go func() {
|
||||
defer entry.wg.Done()
|
||||
feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID)
|
||||
}()
|
||||
|
||||
m.muxLock.Lock()
|
||||
m.entries[accountID] = entry
|
||||
m.muxLock.Unlock()
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// tearDown shuts every goroutine down and closes the netstack listeners.
|
||||
func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) {
|
||||
m.muxLock.Lock()
|
||||
if m.entries[accountID] == entry {
|
||||
delete(m.entries, accountID)
|
||||
delete(m.pendingRoutes, accountID)
|
||||
}
|
||||
m.muxLock.Unlock()
|
||||
|
||||
entry.cancel()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), inboundShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := entry.httpsServer.Shutdown(shutdownCtx); err != nil {
|
||||
m.logger.Debugf("per-account https shutdown: %v", err)
|
||||
}
|
||||
if err := entry.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
m.logger.Debugf("per-account http shutdown: %v", err)
|
||||
}
|
||||
if err := entry.tlsListener.Close(); err != nil {
|
||||
m.logger.Debugf("close per-account tls listener: %v", err)
|
||||
}
|
||||
if err := entry.plainListener.Close(); err != nil {
|
||||
m.logger.Debugf("close per-account plain listener: %v", err)
|
||||
}
|
||||
entry.wg.Wait()
|
||||
}
|
||||
|
||||
// AddRoute records an SNI/host route on the account's per-account router.
|
||||
// Routes registered before the listener is up are queued and replayed
|
||||
// once startup completes.
|
||||
func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok {
|
||||
m.queuePendingLocked(accountID, host, route)
|
||||
m.muxLock.Unlock()
|
||||
return
|
||||
}
|
||||
router := entry.router
|
||||
m.muxLock.Unlock()
|
||||
|
||||
router.AddRoute(host, route)
|
||||
}
|
||||
|
||||
// RemoveRoute drops a previously registered route. Safe to call when the
|
||||
// listener is not yet up; queued copies are pruned in that case.
|
||||
func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
m.dropPendingLocked(accountID, host, svcID)
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok {
|
||||
m.muxLock.Unlock()
|
||||
return
|
||||
}
|
||||
router := entry.router
|
||||
m.muxLock.Unlock()
|
||||
|
||||
router.RemoveRoute(host, svcID)
|
||||
}
|
||||
|
||||
// queuePendingLocked stores or upserts a pending route. Caller holds muxLock.
|
||||
func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
|
||||
queued := m.pendingRoutes[accountID]
|
||||
for i, pr := range queued {
|
||||
if pr.host == host && pr.route.ServiceID == route.ServiceID {
|
||||
queued[i] = pendingInboundRoute{host: host, route: route}
|
||||
m.pendingRoutes[accountID] = queued
|
||||
return
|
||||
}
|
||||
}
|
||||
m.pendingRoutes[accountID] = append(queued, pendingInboundRoute{host: host, route: route})
|
||||
}
|
||||
|
||||
// dropPendingLocked removes any queued route matching host/svcID.
|
||||
// Caller holds muxLock.
|
||||
func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
|
||||
queued, ok := m.pendingRoutes[accountID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
filtered := queued[:0]
|
||||
for _, pr := range queued {
|
||||
if pr.host == host && pr.route.ServiceID == svcID {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, pr)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(m.pendingRoutes, accountID)
|
||||
return
|
||||
}
|
||||
m.pendingRoutes[accountID] = filtered
|
||||
}
|
||||
|
||||
// flushPending applies all queued routes to a freshly-up router.
|
||||
func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) {
|
||||
m.muxLock.Lock()
|
||||
queued := m.pendingRoutes[accountID]
|
||||
delete(m.pendingRoutes, accountID)
|
||||
m.muxLock.Unlock()
|
||||
|
||||
for _, pr := range queued {
|
||||
entry.router.AddRoute(pr.host, pr.route)
|
||||
}
|
||||
}
|
||||
|
||||
// HasInbound reports whether the manager has a live listener for the account.
|
||||
// Used by tests.
|
||||
func (m *inboundManager) HasInbound(accountID types.AccountID) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
_, ok := m.entries[accountID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// PendingRouteCount reports the number of queued routes for the account.
|
||||
// Used by tests.
|
||||
func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
return len(m.pendingRoutes[accountID])
|
||||
}
|
||||
|
||||
// InboundListenerInfo describes the bound addresses of a single
|
||||
// per-account inbound listener. Both addresses live on the embedded
|
||||
// netstack of the account's WireGuard client and share the same tunnel IP.
|
||||
type InboundListenerInfo struct {
|
||||
TunnelIP string
|
||||
HTTPSPort uint16
|
||||
HTTPPort uint16
|
||||
}
|
||||
|
||||
// ListenerInfo returns the inbound listener addresses for the given
|
||||
// account, or ok=false when the account has no live listener. Used by
|
||||
// the status-update RPC and the debug HTTP handler to surface inbound
|
||||
// reachability to operators.
|
||||
func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) {
|
||||
if m == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
entry, ok := m.entries[accountID]
|
||||
if !ok || entry == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
return listenerInfoFromEntry(entry), true
|
||||
}
|
||||
|
||||
// Snapshot returns the inbound listener state for every account that has
|
||||
// a live listener at call time. Empty when --private-inbound is off or
|
||||
// no accounts have come up yet.
|
||||
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.muxLock.Lock()
|
||||
defer m.muxLock.Unlock()
|
||||
if len(m.entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[types.AccountID]InboundListenerInfo, len(m.entries))
|
||||
for id, entry := range m.entries {
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
out[id] = listenerInfoFromEntry(entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// listenerInfoFromEntry extracts the tunnel IP and ports from a live
|
||||
// per-account entry. Both listeners are bound on the same netstack so
|
||||
// their host components match; we still pull the TLS host as the
|
||||
// authoritative source.
|
||||
func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo {
|
||||
info := InboundListenerInfo{HTTPSPort: privateInboundPortHTTPS, HTTPPort: privateInboundPortHTTP}
|
||||
if entry.tlsListener != nil {
|
||||
host, port := splitHostPort(entry.tlsListener.Addr())
|
||||
info.TunnelIP = host
|
||||
if port != 0 {
|
||||
info.HTTPSPort = port
|
||||
}
|
||||
}
|
||||
if entry.plainListener != nil {
|
||||
host, port := splitHostPort(entry.plainListener.Addr())
|
||||
if info.TunnelIP == "" {
|
||||
info.TunnelIP = host
|
||||
}
|
||||
if port != 0 {
|
||||
info.HTTPPort = port
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// splitHostPort extracts host and port from a net.Addr, returning the
|
||||
// zero values when the address is missing or malformed.
|
||||
func splitHostPort(addr net.Addr) (string, uint16) {
|
||||
if addr == nil {
|
||||
return "", 0
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return "", 0
|
||||
}
|
||||
if portStr == "" {
|
||||
return host, 0
|
||||
}
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return host, 0
|
||||
}
|
||||
return host, uint16(port)
|
||||
}
|
||||
|
||||
// feedRouterFromListener accepts on the plain-HTTP netstack listener and
|
||||
// hands every connection to the account's router. The router peeks the
|
||||
// first byte and dispatches to the plain-HTTP channel for non-TLS
|
||||
// streams or the TLS channel for ClientHellos that arrive on :80.
|
||||
func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
|
||||
continue
|
||||
}
|
||||
router.HandleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// accountDialResolver returns a DialResolver bound to a single account's
|
||||
// embedded client. The router only ever serves traffic for that account
|
||||
// so the supplied accountID is ignored at dial time.
|
||||
func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver {
|
||||
return func(_ types.AccountID) (types.DialContextFunc, error) {
|
||||
return client.DialContext, nil
|
||||
}
|
||||
}
|
||||
|
||||
// accountTunnelLookup returns a TunnelLookupFunc backed by the embedded
|
||||
// client's peerstore for a single account. Phase 3 uses the result to
|
||||
// short-circuit ValidateTunnelPeer when the source IP is not in the
|
||||
// account's roster and to seed the cached identity for known peers.
|
||||
func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ip netip.Addr) (auth.PeerIdentity, bool) {
|
||||
pubKey, fqdn, ok := client.IdentityForIP(ip)
|
||||
if !ok {
|
||||
return auth.PeerIdentity{}, false
|
||||
}
|
||||
return auth.PeerIdentity{
|
||||
PubKey: pubKey,
|
||||
TunnelIP: ip,
|
||||
FQDN: fqdn,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
// withTunnelLookup returns an http.Handler that attaches the per-account
|
||||
// peerstore lookup to every request's context before delegating to next.
|
||||
// Calling on the host-level listener is a no-op because that path never
|
||||
// installs this wrapper, so the existing behaviour stays byte-for-byte
|
||||
// identical when --private-inbound is off or the request didn't arrive
|
||||
// on a per-account listener.
|
||||
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
|
||||
if lookup == nil {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := auth.WithTunnelLookup(r.Context(), lookup)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// inboundDebugAdapter adapts *inboundManager to the debug.InboundProvider
|
||||
// interface so the debug HTTP handler can render per-account inbound
|
||||
// listener state without importing the proxy package.
|
||||
type inboundDebugAdapter struct {
|
||||
mgr *inboundManager
|
||||
}
|
||||
|
||||
// InboundListeners returns a snapshot of the live per-account inbound
|
||||
// listeners formatted for the debug surface.
|
||||
func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo {
|
||||
if a.mgr == nil {
|
||||
return nil
|
||||
}
|
||||
snap := a.mgr.Snapshot()
|
||||
if len(snap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[types.AccountID]debug.InboundListenerInfo, len(snap))
|
||||
for id, info := range snap {
|
||||
out[id] = debug.InboundListenerInfo{
|
||||
TunnelIP: info.TunnelIP,
|
||||
HTTPSPort: info.HTTPSPort,
|
||||
HTTPPort: info.HTTPPort,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// newInboundErrorLog routes a per-account http.Server's stdlib error
|
||||
// stream through logrus at warn level.
|
||||
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
|
||||
return stdlog.New(logger.WithFields(log.Fields{
|
||||
"inbound-http": scheme,
|
||||
"account_id": accountID,
|
||||
}).WriterLevel(log.WarnLevel), "", 0)
|
||||
}
|
||||
502
proxy/inbound_test.go
Normal file
502
proxy/inbound_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// bufioReader wraps the connection in a buffered reader so http.ReadResponse
|
||||
// can parse the response line + headers off the wire.
|
||||
func bufioReader(conn net.Conn) *bufio.Reader {
|
||||
return bufio.NewReader(conn)
|
||||
}
|
||||
|
||||
// quietLogger returns a logger that emits nothing — keeps test output tidy.
|
||||
func quietLogger() *log.Logger {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
return logger
|
||||
}
|
||||
|
||||
func TestInboundManager_RouteScopedToAccount(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountA := types.AccountID("acct-a")
|
||||
accountB := types.AccountID("acct-b")
|
||||
|
||||
mgr.AddRoute(accountA, "shared.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountA, ServiceID: "svc-a", Domain: "shared.example"})
|
||||
mgr.AddRoute(accountB, "other.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: "other.example"})
|
||||
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountA), "account A should have one queued route")
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountB), "account B should have one queued route")
|
||||
|
||||
mgr.RemoveRoute(accountA, "shared.example", "svc-a")
|
||||
mgr.RemoveRoute(accountB, "other.example", "svc-b")
|
||||
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountA), "queue should drain on remove")
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountB), "queue should drain on remove")
|
||||
}
|
||||
|
||||
func TestInboundManager_PendingThenFlush(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountID := types.AccountID("acct-1")
|
||||
host := nbtcp.SNIHost("example.test")
|
||||
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"}
|
||||
|
||||
mgr.AddRoute(accountID, host, route)
|
||||
require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up")
|
||||
|
||||
// Simulate listener up by registering a fake entry, then flushing.
|
||||
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
entry := &inboundEntry{router: router}
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = entry
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
mgr.flushPending(accountID, entry)
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush")
|
||||
}
|
||||
|
||||
// fakeAddr is a stub net.Addr for tests that don't actually bind sockets.
|
||||
type fakeAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a *fakeAddr) Network() string { return "tcp" }
|
||||
func (a *fakeAddr) String() string { return a.addr }
|
||||
|
||||
// fakeMgmtClient implements roundtrip.managementClient for tests.
|
||||
type fakeMgmtClient struct{}
|
||||
|
||||
func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// TestServer_PrivateInbound_NotEnabled_NoManager confirms that with
|
||||
// --private off the inbound manager is nil and the standalone proxy
|
||||
// keeps its zero-overhead default path.
|
||||
func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) {
|
||||
s := &Server{Logger: quietLogger(), Private: false}
|
||||
s.initPrivateInbound(http.NotFoundHandler(), nil)
|
||||
assert.Nil(t, s.inbound, "manager should remain nil when --private is off")
|
||||
}
|
||||
|
||||
// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that
|
||||
// --private alone wires the manager into the NetBird transport, so
|
||||
// AddPeer / RemovePeer drive the lifecycle.
|
||||
func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
|
||||
s := &Server{Logger: quietLogger(), Private: true}
|
||||
// Construct a NetBird transport. We can't actually start the embedded
|
||||
// client here (that needs a real management server), but we can
|
||||
// confirm that the lifecycle callbacks are registered.
|
||||
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
|
||||
MgmtAddr: "http://invalid.test",
|
||||
}, quietLogger(), nil, fakeMgmtClient{})
|
||||
|
||||
s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec
|
||||
require.NotNil(t, s.inbound, "manager should be set when --private is on")
|
||||
assert.NotNil(t, s.inbound.handler, "handler should be set on manager")
|
||||
assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager")
|
||||
}
|
||||
|
||||
// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that
|
||||
// when the listener is already up, AddRoute writes straight to the
|
||||
// router without queueing.
|
||||
func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-1")
|
||||
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = &inboundEntry{router: router}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
host := nbtcp.SNIHost("ready.example")
|
||||
mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)})
|
||||
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up")
|
||||
}
|
||||
|
||||
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
|
||||
// bit reported upstream tracks --private exclusively. The previous
|
||||
// --private-inbound flag has been folded into --private.
|
||||
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
private bool
|
||||
expected bool
|
||||
}{
|
||||
{"off", false, false},
|
||||
{"on", true, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Server{Private: tt.private}
|
||||
assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA verifies that a
|
||||
// service registered for account B is invisible to a router serving
|
||||
// account A. We exercise the path through real per-account routers.
|
||||
func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
|
||||
accountA := types.AccountID("acct-a")
|
||||
accountB := types.AccountID("acct-b")
|
||||
routerA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
routerB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
|
||||
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountA] = &inboundEntry{router: routerA}
|
||||
mgr.entries[accountB] = &inboundEntry{router: routerB}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
host := nbtcp.SNIHost("shared.example")
|
||||
mgr.AddRoute(accountB, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: string(host)})
|
||||
|
||||
// Account A's router should have no routes; account B's should have one.
|
||||
// We check via IsEmpty — true means no routes and no fallback.
|
||||
assert.True(t, routerA.IsEmpty(), "account A router must not see account B's mappings")
|
||||
assert.False(t, routerB.IsEmpty(), "account B router should hold its own mapping")
|
||||
}
|
||||
|
||||
// TestInboundEntry_ShutdownIdempotent ensures that tearDown can run twice
|
||||
// without panicking — callers may invoke it from RemovePeer + StopAll.
|
||||
func TestInboundEntry_ShutdownIdempotent(t *testing.T) {
|
||||
t.Skip("teardown requires real netstack listeners; covered by integration tests")
|
||||
}
|
||||
|
||||
// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account
|
||||
// router pipeline against a loopback listener (proxy of a netstack
|
||||
// listener for test purposes): a plain HTTP request lands on the plain
|
||||
// http.Server and the inner handler observes a nil r.TLS, which is what
|
||||
// auth.ResolveProto translates to "http" in the real pipeline.
|
||||
func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
var captured atomic.Value
|
||||
captured.Store("")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil {
|
||||
captured.Store("http")
|
||||
} else {
|
||||
captured.Store("https")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
hostListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "loopback listener bind must succeed")
|
||||
defer hostListener.Close()
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr()))
|
||||
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
|
||||
defer func() { _ = httpServer.Close() }()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
|
||||
go func() { _ = router.Serve(ctx, hostListener) }()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err, "plain HTTP dial must succeed")
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
|
||||
require.NoError(t, err, "write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(conn), nil)
|
||||
require.NoError(t, err, "must read response")
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path")
|
||||
}
|
||||
|
||||
// TestWithTunnelLookup_AttachesLookupToContext verifies that requests
|
||||
// flowing through the per-account handler wrapper carry the peerstore
|
||||
// lookup function. Phase 3's local-first deny path depends on this.
|
||||
func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) {
|
||||
expected := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"}
|
||||
lookup := auth.TunnelLookupFunc(func(_ netip.Addr) (auth.PeerIdentity, bool) {
|
||||
return expected, true
|
||||
})
|
||||
|
||||
var observed auth.TunnelLookupFunc
|
||||
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
observed = auth.TunnelLookupFromContext(r.Context())
|
||||
})
|
||||
|
||||
handler := withTunnelLookup(inner, lookup)
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), r)
|
||||
|
||||
require.NotNil(t, observed, "wrapper must inject the lookup into the request context")
|
||||
got, ok := observed(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.True(t, ok, "lookup must round-trip through context")
|
||||
assert.Equal(t, expected.FQDN, got.FQDN, "lookup must return the same identity it was constructed with")
|
||||
}
|
||||
|
||||
// TestWithTunnelLookup_NilLookupIsNoop confirms the wrapper is a pure
|
||||
// pass-through when no lookup is provided. Required for the host-level
|
||||
// listener path to keep its byte-for-byte previous behaviour.
|
||||
func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) {
|
||||
var called bool
|
||||
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
assert.Nil(t, auth.TunnelLookupFromContext(r.Context()), "host-level path must not see a lookup function")
|
||||
})
|
||||
|
||||
handler := withTunnelLookup(inner, nil)
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
handler.ServeHTTP(httptest.NewRecorder(), r)
|
||||
assert.True(t, called, "wrapper without lookup must still invoke next")
|
||||
}
|
||||
|
||||
// fakeListener satisfies net.Listener for snapshot tests without binding
|
||||
// a real socket on the netstack.
|
||||
type fakeListener struct {
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func (f *fakeListener) Accept() (net.Conn, error) { return nil, net.ErrClosed }
|
||||
func (f *fakeListener) Close() error { return nil }
|
||||
func (f *fakeListener) Addr() net.Addr { return f.addr }
|
||||
|
||||
// TestInboundManager_ListenerInfo confirms ListenerInfo and Snapshot
|
||||
// surface the bound tunnel-IP and ports for live entries.
|
||||
func TestInboundManager_ListenerInfo(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-info")
|
||||
|
||||
tlsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS}
|
||||
plainAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP}
|
||||
mgr.muxLock.Lock()
|
||||
mgr.entries[accountID] = &inboundEntry{
|
||||
tlsListener: &fakeListener{addr: tlsAddr},
|
||||
plainListener: &fakeListener{addr: plainAddr},
|
||||
}
|
||||
mgr.muxLock.Unlock()
|
||||
|
||||
info, ok := mgr.ListenerInfo(accountID)
|
||||
require.True(t, ok, "ListenerInfo must report ok for live entry")
|
||||
assert.Equal(t, "100.64.0.5", info.TunnelIP, "tunnel IP must come from listener address")
|
||||
assert.Equal(t, uint16(privateInboundPortHTTPS), info.HTTPSPort, "TLS port must match bound port")
|
||||
assert.Equal(t, uint16(privateInboundPortHTTP), info.HTTPPort, "HTTP port must match bound port")
|
||||
|
||||
snap := mgr.Snapshot()
|
||||
require.Len(t, snap, 1, "snapshot must contain exactly one entry")
|
||||
assert.Equal(t, info, snap[accountID], "snapshot entry must equal direct lookup")
|
||||
|
||||
_, ok = mgr.ListenerInfo(types.AccountID("missing"))
|
||||
assert.False(t, ok, "ListenerInfo must report ok=false for unknown accounts")
|
||||
}
|
||||
|
||||
// TestInboundManager_NilManagerSafe ensures the observability accessors
|
||||
// are safe to call when --private-inbound is off (nil manager).
|
||||
func TestInboundManager_NilManagerSafe(t *testing.T) {
|
||||
var mgr *inboundManager
|
||||
_, ok := mgr.ListenerInfo("anything")
|
||||
assert.False(t, ok, "nil manager must return ok=false")
|
||||
assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot")
|
||||
}
|
||||
|
||||
// TestInboundManager_ConcurrentAddRemove pounds AddRoute / RemoveRoute
|
||||
// from multiple goroutines to expose any locking gaps.
|
||||
func TestInboundManager_ConcurrentAddRemove(t *testing.T) {
|
||||
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
|
||||
accountID := types.AccountID("acct-1")
|
||||
const workers = 32
|
||||
const iterations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
host := nbtcp.SNIHost("example.test")
|
||||
svc := types.ServiceID("svc")
|
||||
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"}
|
||||
for j := 0; j < iterations; j++ {
|
||||
mgr.AddRoute(accountID, host, route)
|
||||
mgr.RemoveRoute(accountID, host, svc)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("concurrent add/remove timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_DeliversConnectionToHandler validates the
|
||||
// per-account inbound chain end-to-end with a loopback listener
|
||||
// substituted for the embedded netstack: a TCP connection arriving at
|
||||
// the plain listener flows through feedRouterFromListener, the router's
|
||||
// peek-and-dispatch, the wrapped HTTP server, and reaches the user
|
||||
// handler. If the embedded netstack is delivering connections at all,
|
||||
// this is the path they take. Failures localise to wiring bugs in the
|
||||
// proxy, not the netstack.
|
||||
func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
hits := make(chan string, 1)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits <- r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("served"))
|
||||
})
|
||||
|
||||
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "plain loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = plainLn.Close() })
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr()))
|
||||
|
||||
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
|
||||
t.Cleanup(func() { _ = httpServer.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
|
||||
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1"))
|
||||
|
||||
conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err, "must connect to the plain listener")
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n"))
|
||||
require.NoError(t, err, "request write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(conn), nil)
|
||||
require.NoError(t, err, "must read response from server")
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached")
|
||||
|
||||
select {
|
||||
case host := <-hits:
|
||||
assert.Equal(t, "app.example", host, "handler must observe the request Host")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handler was not invoked — connection did not flow through router → http server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a
|
||||
// TLS ClientHello arriving on the plain listener is detected by the
|
||||
// router peek and re-dispatched to the TLS channel — the cross-channel
|
||||
// fallback the inbound stack relies on for HTTPS-on-:80 testing.
|
||||
func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) {
|
||||
logger := quietLogger()
|
||||
|
||||
hits := make(chan string, 1)
|
||||
tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits <- r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("served-tls"))
|
||||
})
|
||||
|
||||
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "plain loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = plainLn.Close() })
|
||||
|
||||
tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "tls loopback bind must succeed")
|
||||
t.Cleanup(func() { _ = tlsLn.Close() })
|
||||
|
||||
router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr()))
|
||||
|
||||
tlsConfig := selfSignedTLSConfig(t)
|
||||
httpsServer := &http.Server{
|
||||
Handler: tlsHandler,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadHeaderTimeout: time.Second,
|
||||
}
|
||||
t.Cleanup(func() { _ = httpsServer.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }()
|
||||
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls"))
|
||||
|
||||
tlsConn, err := tls.Dial("tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec
|
||||
require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)")
|
||||
t.Cleanup(func() { _ = tlsConn.Close() })
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, req.Write(tlsConn), "TLS request write must succeed")
|
||||
|
||||
resp, err := http.ReadResponse(bufioReader(tlsConn), req)
|
||||
require.NoError(t, err, "must read TLS response")
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached")
|
||||
|
||||
select {
|
||||
case host := <-hits:
|
||||
assert.Equal(t, "app.example", host, "TLS handler must observe the request Host")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("TLS handler was not invoked — peek/dispatch path is broken")
|
||||
}
|
||||
}
|
||||
|
||||
func selfSignedTLSConfig(t *testing.T) *tls.Config {
|
||||
t.Helper()
|
||||
cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM)
|
||||
require.NoError(t, err, "load static self-signed cert")
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
|
||||
}
|
||||
|
||||
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
|
||||
// 127.0.0.1 — only used by tests that need a working TLS handshake.
|
||||
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
|
||||
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
|
||||
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
|
||||
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
|
||||
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
|
||||
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
|
||||
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
|
||||
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
|
||||
6MF9+Yw1Yy0t
|
||||
-----END CERTIFICATE-----`)
|
||||
var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
|
||||
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
||||
-----END EC PRIVATE KEY-----`)
|
||||
47
proxy/internal/auth/identity.go
Normal file
47
proxy/internal/auth/identity.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// PeerIdentity describes the locally-known facts about a peer reachable on
|
||||
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
|
||||
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
|
||||
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
|
||||
// Phase V2 will populate them once RemotePeerConfig carries user identity.
|
||||
type PeerIdentity struct {
|
||||
PubKey string
|
||||
TunnelIP netip.Addr
|
||||
FQDN string
|
||||
|
||||
// V2 fields (zero in V1).
|
||||
UserID string
|
||||
Email string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
|
||||
// available peerstore data. ok=false means the IP is not in the calling
|
||||
// account's roster.
|
||||
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)
|
||||
|
||||
type tunnelLookupContextKey struct{}
|
||||
|
||||
// WithTunnelLookup attaches a per-account peerstore lookup function to
|
||||
// the request context. The auth middleware calls this lookup before
|
||||
// hitting management's ValidateTunnelPeer to short-circuit unknown IPs
|
||||
// and to skip the RPC for already-cached identities.
|
||||
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
|
||||
if lookup == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
|
||||
}
|
||||
|
||||
// TunnelLookupFromContext returns the peerstore lookup attached to ctx,
|
||||
// or nil when the request did not arrive on a per-account listener.
|
||||
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
|
||||
v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc)
|
||||
return v
|
||||
}
|
||||
@@ -36,6 +36,7 @@ type authenticator interface {
|
||||
// SessionValidator validates session tokens and checks user access permissions.
|
||||
type SessionValidator interface {
|
||||
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
|
||||
ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
|
||||
}
|
||||
|
||||
// Scheme defines an authentication mechanism for a domain.
|
||||
@@ -56,12 +57,21 @@ type DomainConfig struct {
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
IPRestrictions *restrict.Filter
|
||||
// Private routes the domain through ValidateTunnelPeer; failure → 403.
|
||||
Private bool
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
UserID string
|
||||
UserEmail string
|
||||
Valid bool
|
||||
DeniedReason string
|
||||
Groups []string
|
||||
// GroupNames carries the human-readable display names for Groups,
|
||||
// ordered identically (positional pairing). May be shorter than
|
||||
// Groups for tokens minted before names were embedded; the consumer
|
||||
// falls back to ids for missing positions.
|
||||
GroupNames []string
|
||||
}
|
||||
|
||||
// Middleware applies per-domain authentication and IP restriction checks.
|
||||
@@ -71,6 +81,7 @@ type Middleware struct {
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
geo restrict.GeoResolver
|
||||
tunnelCache *tunnelValidationCache
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware. The sessionValidator is
|
||||
@@ -84,6 +95,7 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
geo: geo,
|
||||
tunnelCache: newTunnelValidationCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +123,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Private services bypass operator schemes and gate on tunnel peer.
|
||||
if config.Private {
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Domains with no authentication schemes pass through after IP checks.
|
||||
if len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -129,10 +150,54 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.blockOIDCOnPlainHTTP(w, r, config) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
// requestIsPlainHTTP reports whether the request arrived without TLS.
|
||||
// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block.
|
||||
func requestIsPlainHTTP(r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
|
||||
// hasOIDCScheme reports whether any of the configured schemes requires
|
||||
// TLS to round-trip safely with an external IdP.
|
||||
func hasOIDCScheme(schemes []Scheme) bool {
|
||||
for _, s := range schemes {
|
||||
if s.Type() == auth.MethodOIDC {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit
|
||||
// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing
|
||||
// the misconfiguration here yields a clearer error than the IdP's
|
||||
// "invalid redirect_uri" round-trip.
|
||||
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
|
||||
if !requestIsPlainHTTP(r) {
|
||||
return false
|
||||
}
|
||||
if !hasOIDCScheme(config.Schemes) {
|
||||
return false
|
||||
}
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": r.Host,
|
||||
"remote": r.RemoteAddr,
|
||||
}).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")
|
||||
http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
@@ -246,18 +311,117 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if requestIsPlainHTTP(r) {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": r.RemoteAddr,
|
||||
}).Warn("session cookie on plain HTTP path; cookie auth requires TLS — use port 443")
|
||||
}
|
||||
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(email)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the
|
||||
// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy
|
||||
// asks management to resolve it to a peer/user and to gate by the service's
|
||||
// distribution_groups. On success the proxy installs the freshly minted JWT
|
||||
// as a session cookie, sets UserID + Method=oidc on the captured data, and
|
||||
// forwards directly — operators see the same access-log shape as if the user
|
||||
// had completed an OIDC redirect. Any failure (private-range mismatch,
|
||||
// management unreachable, peer unknown, user not in group) returns false so
|
||||
// the caller falls back to the existing OIDC scheme dispatch.
|
||||
//
|
||||
// Phase 3 adds a local-first short-circuit: when the request arrived on a
|
||||
// per-account inbound listener the context carries a peerstore lookup
|
||||
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
|
||||
// roster the proxy denies fast without calling management. If the lookup
|
||||
// confirms a known peer the RPC still runs for the user-identity tail
|
||||
// (UserID + group access), but its result is cached for tunnelCacheTTL so
|
||||
// repeat requests skip management entirely.
|
||||
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
if mw.sessionValidator == nil {
|
||||
return false
|
||||
}
|
||||
clientIP := mw.resolveClientIP(r)
|
||||
if !clientIP.IsValid() {
|
||||
return false
|
||||
}
|
||||
if !isTunnelSourceIP(clientIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
|
||||
if _, ok := lookup(clientIP); !ok {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": clientIP,
|
||||
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
|
||||
accountID: config.AccountID,
|
||||
tunnelIP: clientIP,
|
||||
domain: host,
|
||||
}, mw.validateTunnelPeer)
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
|
||||
return false
|
||||
}
|
||||
if !resp.GetValid() || resp.GetSessionToken() == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(resp.GetUserId())
|
||||
cd.SetUserEmail(resp.GetUserEmail())
|
||||
cd.SetUserGroups(resp.GetPeerGroupIds())
|
||||
cd.SetUserGroupNames(resp.GetPeerGroupNames())
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// validateTunnelPeer adapts the SessionValidator interface to the cache's
|
||||
// validateTunnelPeerFn signature.
|
||||
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
|
||||
}
|
||||
|
||||
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
|
||||
// allocates tunnel addresses from by default. IsPrivate() doesn't include
|
||||
// it, so we check it explicitly.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
// isTunnelSourceIP reports whether ip falls within an address range typical
|
||||
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
|
||||
// (NetBird's default range). Loopback and link-local are excluded — the
|
||||
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
|
||||
func isTunnelSourceIP(ip netip.Addr) bool {
|
||||
if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
|
||||
return false
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
return cgnatPrefix.Contains(ip)
|
||||
}
|
||||
|
||||
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
|
||||
// the request is forwarded directly (no redirect), which is important for API clients.
|
||||
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
@@ -286,7 +450,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
|
||||
if err != nil {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
@@ -298,7 +462,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
setHeaderCapturedData(r.Context(), result.UserID)
|
||||
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -306,6 +470,9 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
}
|
||||
|
||||
@@ -315,7 +482,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
|
||||
if errors.Is(err, ErrHeaderAuthFailed) {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -327,7 +494,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
|
||||
return true
|
||||
}
|
||||
|
||||
func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
|
||||
cd := proxy.CapturedDataFromContext(ctx)
|
||||
if cd == nil {
|
||||
return
|
||||
@@ -335,6 +502,9 @@ func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(userEmail)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
@@ -405,6 +575,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
@@ -419,6 +592,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
@@ -454,12 +630,9 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AddDomain registers authentication schemes for the given domain.
|
||||
// If schemes are provided, a valid session public key is required to sign/verify
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
|
||||
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
|
||||
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
@@ -467,6 +640,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -488,6 +662,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -518,18 +693,27 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
}).Debug("Session validation denied")
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: false,
|
||||
DeniedReason: resp.DeniedReason,
|
||||
Groups: resp.GetPeerGroupIds(),
|
||||
GroupNames: resp.GetPeerGroupNames(),
|
||||
}, nil
|
||||
}
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: true,
|
||||
Groups: resp.GetPeerGroupIds(),
|
||||
GroupNames: resp.GetPeerGroupNames(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &validationResult{UserID: userID, Valid: true}, nil
|
||||
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns the request URI with the session_token query
|
||||
|
||||
@@ -4,13 +4,16 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -62,7 +65,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -79,7 +82,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -93,7 +96,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -108,7 +111,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -121,7 +124,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -137,8 +140,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -154,7 +157,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -178,7 +181,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -195,7 +198,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -216,7 +219,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -237,9 +240,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -262,15 +265,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
assert.Equal(t, "authenticated", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
|
||||
// JWT's groups claim into CapturedData so policy-aware middlewares can
|
||||
// authorise without an extra management round-trip.
|
||||
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
groups := []string{"engineering", "sre"}
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd, "captured data must be present in request context")
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
|
||||
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -293,10 +329,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -320,10 +356,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -345,7 +381,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -357,7 +393,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -410,7 +446,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -427,7 +463,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First scheme (PIN) always fails, second scheme (password) succeeds.
|
||||
@@ -446,7 +482,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -476,7 +512,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -500,7 +536,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
@@ -509,10 +545,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
@@ -536,7 +572,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -563,7 +599,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -590,7 +626,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -678,7 +714,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -714,7 +750,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -755,7 +791,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -781,11 +817,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
|
||||
return "", oidcURL, nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -809,11 +846,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -834,7 +872,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
|
||||
// returns a signed session token when the expected header value is provided.
|
||||
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
@@ -852,7 +890,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -895,7 +933,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
// Also add a PIN scheme so we can verify fallthrough behavior.
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -915,7 +953,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -938,7 +976,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
|
||||
return nil, errors.New("gRPC unavailable")
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -955,7 +993,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -1006,7 +1044,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && accepted[ha.GetHeaderValue()] {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
@@ -1015,7 +1053,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
|
||||
// Single Header scheme (as if one entry existed), but the mock checks both values.
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -1059,3 +1097,173 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
assert.False(t, backendCalled, "unknown token should be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
|
||||
// scheme is configured and the request arrived without TLS, the middleware
|
||||
// short-circuits with a 400 instead of dispatching to the IdP redirect.
|
||||
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
|
||||
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
|
||||
// over TLS — the block only fires on plain HTTP.
|
||||
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
|
||||
}
|
||||
|
||||
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
|
||||
// block only fires when an OIDC scheme is configured. PIN-only domains
|
||||
// pass through normally on plain HTTP.
|
||||
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
|
||||
}
|
||||
|
||||
// TestProtect_SessionCookieOnPlainHTTP_LogsWarn verifies that a request
|
||||
// carrying a valid session cookie over plain HTTP is still forwarded but
|
||||
// emits a WARN-level log line for the operator.
|
||||
func TestProtect_SessionCookieOnPlainHTTP_LogsWarn(t *testing.T) {
|
||||
logger, hook := newTestLogger()
|
||||
|
||||
mw := NewMiddleware(logger, nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.True(t, backendCalled, "backend should still be reached — we don't drop the cookie")
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var found bool
|
||||
for _, entry := range hook.entries() {
|
||||
if entry.Level == log.WarnLevel && strings.Contains(entry.Message, "session cookie on plain HTTP path") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected WARN log for session cookie on plain HTTP")
|
||||
}
|
||||
|
||||
// TestProtect_SessionCookieOverTLS_NoWarn confirms the WARN only fires
|
||||
// on plain HTTP — TLS requests with a session cookie behave as before.
|
||||
func TestProtect_SessionCookieOverTLS_NoWarn(t *testing.T) {
|
||||
logger, hook := newTestLogger()
|
||||
|
||||
mw := NewMiddleware(logger, nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
for _, entry := range hook.entries() {
|
||||
assert.NotContains(t, entry.Message, "session cookie on plain HTTP path", "no plain-HTTP cookie warn expected over TLS")
|
||||
}
|
||||
}
|
||||
|
||||
// captureHook is a minimal logrus hook that records emitted entries for
|
||||
// inspection in tests. It avoids pulling in the full sirupsen test
|
||||
// helpers package (which the rest of the codebase doesn't use).
|
||||
type captureHook struct {
|
||||
mu sync.Mutex
|
||||
records []log.Entry
|
||||
}
|
||||
|
||||
func (h *captureHook) Levels() []log.Level { return log.AllLevels }
|
||||
|
||||
func (h *captureHook) Fire(entry *log.Entry) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.records = append(h.records, *entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *captureHook) entries() []log.Entry {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return append([]log.Entry{}, h.records...)
|
||||
}
|
||||
|
||||
// newTestLogger builds an isolated logrus logger with a capture hook so
|
||||
// tests can assert on emitted records without contending on the global
|
||||
// logger.
|
||||
func newTestLogger() (*log.Logger, *captureHook) {
|
||||
hook := &captureHook{}
|
||||
logger := log.New()
|
||||
logger.SetOutput(io.Discard)
|
||||
logger.AddHook(hook)
|
||||
logger.SetLevel(log.DebugLevel)
|
||||
return logger, hook
|
||||
}
|
||||
|
||||
171
proxy/internal/auth/tunnel_cache.go
Normal file
171
proxy/internal/auth/tunnel_cache.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is
|
||||
// reused before re-fetching from management. 5 minutes balances freshness
|
||||
// against management load on busy mesh networks.
|
||||
const tunnelCacheTTL = 300 * time.Second
|
||||
|
||||
// tunnelCachePerAccount caps the number of cached identities per account.
|
||||
// Bounded eviction avoids memory growth in pathological cases (huge peer
|
||||
// roster, brief request bursts) while staying generous for normal use.
|
||||
const tunnelCachePerAccount = 1024
|
||||
|
||||
// tunnelCacheKey identifies a cached entry by tunnel IP and originating
|
||||
// account. Domain is part of the value, not the key, because the
|
||||
// management response is per (account, IP) — domain only gates whether a
|
||||
// re-fetch is needed if the operator is accessing a different service.
|
||||
type tunnelCacheKey struct {
|
||||
accountID types.AccountID
|
||||
tunnelIP netip.Addr
|
||||
domain string
|
||||
}
|
||||
|
||||
// tunnelCacheEntry stores a positive validation response with the time it
|
||||
// was minted. Entries past tunnelCacheTTL are treated as misses.
|
||||
type tunnelCacheEntry struct {
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
cachedAt time.Time
|
||||
}
|
||||
|
||||
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
|
||||
// (accountID, tunnelIP, domain). Only successful, valid responses are
|
||||
// cached — denials skip the cache so policy changes apply immediately.
|
||||
// Single-flight de-duplicates concurrent fetches for the same key so a
|
||||
// burst of cold requests collapses into a single RPC.
|
||||
type tunnelValidationCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[types.AccountID]*accountBucket
|
||||
flight singleflight.Group
|
||||
ttl time.Duration
|
||||
maxSize int
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// accountBucket holds the cached entries for a single account, with a
|
||||
// FIFO eviction queue used when the bucket exceeds maxSize.
|
||||
type accountBucket struct {
|
||||
items map[tunnelCacheKey]tunnelCacheEntry
|
||||
order []tunnelCacheKey
|
||||
}
|
||||
|
||||
// newTunnelValidationCache constructs a cache with default TTL and bounds.
|
||||
func newTunnelValidationCache() *tunnelValidationCache {
|
||||
return &tunnelValidationCache{
|
||||
entries: make(map[types.AccountID]*accountBucket),
|
||||
ttl: tunnelCacheTTL,
|
||||
maxSize: tunnelCachePerAccount,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// get returns a cached response for the key, or nil when missing or
|
||||
// expired. Expired entries are evicted lazily on read.
|
||||
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
entry, ok := bucket.items[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if c.now().Sub(entry.cachedAt) > c.ttl {
|
||||
delete(bucket.items, key)
|
||||
bucket.order = removeKey(bucket.order, key)
|
||||
return nil
|
||||
}
|
||||
return entry.resp
|
||||
}
|
||||
|
||||
// put records a positive response under the key. Evicts the oldest entry
|
||||
// in the account's bucket when the bound is exceeded.
|
||||
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
|
||||
c.entries[key.accountID] = bucket
|
||||
}
|
||||
if _, exists := bucket.items[key]; !exists {
|
||||
bucket.order = append(bucket.order, key)
|
||||
}
|
||||
bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}
|
||||
|
||||
for len(bucket.order) > c.maxSize {
|
||||
oldest := bucket.order[0]
|
||||
bucket.order = bucket.order[1:]
|
||||
delete(bucket.items, oldest)
|
||||
}
|
||||
}
|
||||
|
||||
// removeKey drops the first occurrence of needle from order. The cache
|
||||
// uses small slices so a linear scan is cheaper than a map+slice combo.
|
||||
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
|
||||
for i, k := range order {
|
||||
if k == needle {
|
||||
return append(order[:i], order[i+1:]...)
|
||||
}
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
// flightKey turns a cache key into a single-flight string. AccountID and
|
||||
// IP isolation by themselves are insufficient because different domains
|
||||
// for the same peer/account may have different group access.
|
||||
func flightKey(key tunnelCacheKey) string {
|
||||
return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain
|
||||
}
|
||||
|
||||
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
|
||||
// the SessionValidator.ValidateTunnelPeer signature without exposing the
|
||||
// gRPC option variadic, since callers don't need it on the cache hot path.
|
||||
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
|
||||
|
||||
// fetch returns a cached response when present, otherwise calls validate
|
||||
// under single-flight and caches the result. Denied responses pass
|
||||
// through but are not cached so policy changes apply immediately.
|
||||
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
|
||||
if resp := c.get(key); resp != nil {
|
||||
return resp, true, nil
|
||||
}
|
||||
|
||||
flight := flightKey(key)
|
||||
res, err, _ := c.flight.Do(flight, func() (any, error) {
|
||||
if cached := c.get(key); cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{
|
||||
TunnelIp: key.tunnelIP.String(),
|
||||
Domain: key.domain,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.GetValid() && resp.GetSessionToken() != "" {
|
||||
c.put(key, resp)
|
||||
}
|
||||
return resp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
resp, _ := res.(*proto.ValidateTunnelPeerResponse)
|
||||
return resp, false, nil
|
||||
}
|
||||
171
proxy/internal/auth/tunnel_cache_test.go
Normal file
171
proxy/internal/auth/tunnel_cache_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey {
|
||||
return tunnelCacheKey{
|
||||
accountID: account,
|
||||
tunnelIP: netip.MustParseAddr(ip),
|
||||
domain: domain,
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnelCache_HitSkipsRPC(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
var calls int32
|
||||
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil
|
||||
}
|
||||
|
||||
resp, fromCache, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp, "first fetch returns RPC response")
|
||||
assert.False(t, fromCache, "first fetch must not be cached")
|
||||
|
||||
resp2, fromCache2, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp2, "second fetch returns cached response")
|
||||
assert.True(t, fromCache2, "second fetch must be served from cache")
|
||||
assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit")
|
||||
}
|
||||
|
||||
func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
clock := time.Now()
|
||||
cache.now = func() time.Time { return clock }
|
||||
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
|
||||
}
|
||||
|
||||
_, _, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "first fetch issues one RPC")
|
||||
|
||||
clock = clock.Add(tunnelCacheTTL + time.Second)
|
||||
|
||||
_, fromCache, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, fromCache, "expired entry must miss the cache")
|
||||
assert.Equal(t, int32(2), atomic.LoadInt32(&calls), "expired entry forces a re-fetch")
|
||||
}
|
||||
|
||||
func TestTunnelCache_DeniedResponseNotCached(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err, "fetch must not error on denied response")
|
||||
}
|
||||
assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately")
|
||||
}
|
||||
|
||||
func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
gate := make(chan struct{})
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
<-gate
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
|
||||
}
|
||||
|
||||
const workers = 16
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
results := make([]bool, workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
resp, _, err := cache.fetch(context.Background(), key, validate)
|
||||
results[idx] = err == nil && resp.GetValid()
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
close(gate)
|
||||
wg.Wait()
|
||||
|
||||
for i, ok := range results {
|
||||
assert.Truef(t, ok, "worker %d should observe a successful response", i)
|
||||
}
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "single-flight must collapse concurrent cold fetches into one RPC")
|
||||
}
|
||||
|
||||
func TestTunnelCache_PerAccountIsolation(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
keyA := newTestKey("acct-a", "100.64.0.10", "svc.example")
|
||||
keyB := newTestKey("acct-b", "100.64.0.10", "svc.example")
|
||||
|
||||
var callsA, callsB int32
|
||||
validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&callsA, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil
|
||||
}
|
||||
validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&callsB, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil
|
||||
}
|
||||
|
||||
respA, _, err := cache.fetch(context.Background(), keyA, validateA)
|
||||
require.NoError(t, err)
|
||||
respB, _, err := cache.fetch(context.Background(), keyB, validateB)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a")
|
||||
assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once")
|
||||
}
|
||||
|
||||
func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
cache.maxSize = 2
|
||||
|
||||
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil
|
||||
}
|
||||
|
||||
keys := []tunnelCacheKey{
|
||||
newTestKey("acct-1", "100.64.0.10", "svc"),
|
||||
newTestKey("acct-1", "100.64.0.11", "svc"),
|
||||
newTestKey("acct-1", "100.64.0.12", "svc"),
|
||||
}
|
||||
for _, k := range keys {
|
||||
_, _, err := cache.fetch(context.Background(), k, validate)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Nil(t, cache.get(keys[0]), "oldest key should be evicted past maxSize")
|
||||
assert.NotNil(t, cache.get(keys[1]), "second-newest must remain cached")
|
||||
assert.NotNil(t, cache.get(keys[2]), "newest must remain cached")
|
||||
}
|
||||
325
proxy/internal/auth/tunnel_lookup_test.go
Normal file
325
proxy/internal/auth/tunnel_lookup_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// stubSessionValidator records ValidateTunnelPeer calls and returns the
|
||||
// pre-canned response. Counts let tests assert RPC traffic.
|
||||
type stubSessionValidator struct {
|
||||
respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse
|
||||
respErr error
|
||||
tunnelCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
|
||||
return &proto.ValidateSessionResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
s.tunnelCalls.Add(1)
|
||||
if s.respErr != nil {
|
||||
return nil, s.respErr
|
||||
}
|
||||
if s.respFn != nil {
|
||||
return s.respFn(in), nil
|
||||
}
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware {
|
||||
t.Helper()
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
require.NoError(t, mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false))
|
||||
return mw
|
||||
}
|
||||
|
||||
func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
r.Host = "svc.example"
|
||||
r.RemoteAddr = remoteAddr
|
||||
return w, r
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast verifies the
|
||||
// short-circuit: a tunnel IP not in the account's roster never reaches
|
||||
// management's ValidateTunnelPeer.
|
||||
func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) {
|
||||
validator := &stubSessionValidator{}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.99:55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.False(t, handled, "unknown peer must fall through, not forward")
|
||||
assert.False(t, called, "next handler must not run for unknown peer")
|
||||
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData verifies the proxy
|
||||
// surfaces the calling peer's group memberships from ValidateTunnelPeerResponse
|
||||
// onto CapturedData so policy-aware middlewares can authorise without an
|
||||
// extra management round-trip.
|
||||
func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
|
||||
groups := []string{"engineering", "sre"}
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "tok",
|
||||
UserId: "user-1",
|
||||
PeerGroupIds: groups,
|
||||
}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
cd := proxy.NewCapturedData("")
|
||||
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
require.True(t, handled, "valid tunnel-peer response must forward")
|
||||
require.True(t, called, "next handler must run")
|
||||
assert.Equal(t, "user-1", cd.GetUserID(), "user id must propagate from tunnel-peer response")
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs verifies that a
|
||||
// known tunnel IP still triggers ValidateTunnelPeer for the user-identity
|
||||
// tail (UserID + group access). Phase 3 only short-circuits the deny path.
|
||||
func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
if ip == knownIP {
|
||||
return PeerIdentity{PubKey: "pk", TunnelIP: ip, FQDN: "peer.netbird.cloud"}, true
|
||||
}
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.True(t, handled, "known peer with valid RPC response must forward")
|
||||
assert.True(t, called, "next handler must run on success")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
|
||||
// behaviour stays intact on the host-level listener (no lookup attached).
|
||||
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.True(t, handled, "host-level path forwards on positive RPC result")
|
||||
assert.True(t, called, "next handler runs on host-level success")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
|
||||
// failure still falls through to the next scheme (no false positive).
|
||||
func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) {
|
||||
validator := &stubSessionValidator{respErr: errors.New("management down")}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{TunnelIP: ip}, true
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
|
||||
assert.False(t, handled, "RPC error must let the caller try other schemes")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC was attempted exactly once")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_CacheReusesPositiveResponse confirms the
|
||||
// (account, IP, domain) cache prevents repeated RPCs for the same peer.
|
||||
func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
require.True(t, handled, "iteration %d should forward", i)
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey ensures cache keys
|
||||
// honour account scoping — same tunnel IP on different accounts must not
|
||||
// collide.
|
||||
func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"}
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
|
||||
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
|
||||
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
|
||||
|
||||
for _, host := range []string{"svc-a.example", "svc-b.example"} {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
|
||||
r.Host = host
|
||||
r.RemoteAddr = "100.64.0.10:55555"
|
||||
config, _ := mw.getDomainConfig(host)
|
||||
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
require.True(t, handled, "host %s should forward", host)
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(2), validator.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache
|
||||
// guarantees that the deny-fast path leaves the cache untouched, so a
|
||||
// subsequent request from the same IP after the peerstore catches up
|
||||
// goes through the normal RPC flow.
|
||||
func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
known := false
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
if known && ip == knownIP {
|
||||
return PeerIdentity{TunnelIP: ip}, true
|
||||
}
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
doRequest := func() bool {
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
return mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
}
|
||||
|
||||
require.False(t, doRequest(), "first request must short-circuit")
|
||||
require.Equal(t, int32(0), validator.tunnelCalls.Load(), "short-circuit must not populate the cache")
|
||||
|
||||
known = true
|
||||
require.True(t, doRequest(), "second request with peer in roster must forward via RPC")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC runs once after peerstore catches up")
|
||||
}
|
||||
|
||||
func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) {
|
||||
mw := NewMiddleware(log.New(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
|
||||
|
||||
called := false
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.10:55555"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "tok",
|
||||
UserId: "user-1",
|
||||
}
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
|
||||
|
||||
called := false
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.10:55555"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.True(t, called)
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
|
||||
@@ -160,6 +159,49 @@ func (c *Client) printClients(data map[string]any) {
|
||||
for _, item := range clients {
|
||||
c.printClientRow(item)
|
||||
}
|
||||
|
||||
c.printInboundListeners(clients)
|
||||
}
|
||||
|
||||
func (c *Client) printInboundListeners(clients []any) {
|
||||
type row struct {
|
||||
accountID string
|
||||
tunnelIP string
|
||||
httpsPort int
|
||||
httpPort int
|
||||
}
|
||||
var rows []row
|
||||
for _, item := range clients {
|
||||
client, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inbound, ok := client["inbound_listener"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tunnelIP, _ := inbound["tunnel_ip"].(string)
|
||||
httpsPort, _ := inbound["https_port"].(float64)
|
||||
httpPort, _ := inbound["http_port"].(float64)
|
||||
accountID, _ := client["account_id"].(string)
|
||||
rows = append(rows, row{
|
||||
accountID: accountID,
|
||||
tunnelIP: tunnelIP,
|
||||
httpsPort: int(httpsPort),
|
||||
httpPort: int(httpPort),
|
||||
})
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out)
|
||||
_, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):")
|
||||
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP")
|
||||
_, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78))
|
||||
for _, r := range rows {
|
||||
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", r.accountID, r.tunnelIP, r.httpsPort, r.httpPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printClientRow(item any) {
|
||||
@@ -219,7 +261,14 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta
|
||||
}
|
||||
|
||||
func (c *Client) printClientStatus(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"])
|
||||
if inbound, ok := data["inbound_listener"].(map[string]any); ok {
|
||||
tunnelIP, _ := inbound["tunnel_ip"].(string)
|
||||
httpsPort, _ := inbound["https_port"].(float64)
|
||||
httpPort, _ := inbound["http_port"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort))
|
||||
}
|
||||
_, _ = fmt.Fprintln(c.out)
|
||||
if status, ok := data["status"].(string); ok {
|
||||
_, _ = fmt.Fprint(c.out, status)
|
||||
}
|
||||
|
||||
@@ -61,6 +61,23 @@ type clientProvider interface {
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
}
|
||||
|
||||
// InboundListenerInfo describes a per-account inbound listener as
|
||||
// surfaced through the debug HTTP handler. Mirrors the proto sub-message
|
||||
// emitted with SendStatusUpdate so dashboards and CLI tooling see the
|
||||
// same shape.
|
||||
type InboundListenerInfo struct {
|
||||
TunnelIP string `json:"tunnel_ip"`
|
||||
HTTPSPort uint16 `json:"https_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
}
|
||||
|
||||
// InboundProvider exposes per-account inbound listener state. Optional;
|
||||
// when nil the debug endpoint omits the inbound section entirely so the
|
||||
// existing JSON shape stays additive.
|
||||
type InboundProvider interface {
|
||||
InboundListeners() map[types.AccountID]InboundListenerInfo
|
||||
}
|
||||
|
||||
// healthChecker provides health probe state.
|
||||
type healthChecker interface {
|
||||
ReadinessProbe() bool
|
||||
@@ -80,6 +97,7 @@ type Handler struct {
|
||||
provider clientProvider
|
||||
health healthChecker
|
||||
certStatus certStatus
|
||||
inbound InboundProvider
|
||||
logger *log.Logger
|
||||
startTime time.Time
|
||||
templates *template.Template
|
||||
@@ -108,6 +126,13 @@ func (h *Handler) SetCertStatus(cs certStatus) {
|
||||
h.certStatus = cs
|
||||
}
|
||||
|
||||
// SetInboundProvider wires per-account inbound listener observability.
|
||||
// Pass nil (or skip the call) to keep the inbound section out of debug
|
||||
// responses on proxies that don't run --private-inbound.
|
||||
func (h *Handler) SetInboundProvider(p InboundProvider) {
|
||||
h.inbound = p
|
||||
}
|
||||
|
||||
func (h *Handler) loadTemplates() error {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
@@ -323,23 +348,35 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
if wantJSON {
|
||||
var inboundAll map[types.AccountID]InboundListenerInfo
|
||||
if h.inbound != nil {
|
||||
inboundAll = h.inbound.InboundListeners()
|
||||
}
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
row := map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
if inb, ok := inboundAll[id]; ok {
|
||||
row["inbound_listener"] = inb
|
||||
}
|
||||
clientsJSON = append(clientsJSON, row)
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
})
|
||||
}
|
||||
if len(inboundAll) > 0 {
|
||||
resp["inbound_listener_count"] = len(inboundAll)
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -421,10 +458,14 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
})
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
}
|
||||
if info, ok := h.inboundInfoFor(accountID); ok {
|
||||
resp["inbound_listener"] = info
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -437,6 +478,18 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
// inboundInfoFor returns the inbound listener info for an account, or
|
||||
// ok=false when no inbound provider is wired or the account has no live
|
||||
// listener.
|
||||
func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) {
|
||||
if h.inbound == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
all := h.inbound.InboundListeners()
|
||||
info, ok := all[accountID]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
|
||||
@@ -52,8 +52,15 @@ type CapturedData struct {
|
||||
origin ResponseOrigin
|
||||
clientIP netip.Addr
|
||||
userID string
|
||||
authMethod string
|
||||
metadata map[string]string
|
||||
userEmail string
|
||||
userGroups []string
|
||||
// userGroupNames pairs positionally with userGroups; populated from
|
||||
// the JWT's group_names claim or from ValidateSession/Tunnel
|
||||
// responses. Slice may be shorter than userGroups for tokens minted
|
||||
// before names were resolvable.
|
||||
userGroupNames []string
|
||||
authMethod string
|
||||
metadata map[string]string
|
||||
}
|
||||
|
||||
// NewCapturedData creates a CapturedData with the given request ID.
|
||||
@@ -138,6 +145,81 @@ func (c *CapturedData) GetUserID() string {
|
||||
return c.userID
|
||||
}
|
||||
|
||||
// SetUserEmail records the authenticated user's email address. Used by
|
||||
// policy-aware middlewares to stamp identity onto upstream requests
|
||||
// (e.g. x-litellm-end-user-id) without a management round-trip.
|
||||
func (c *CapturedData) SetUserEmail(email string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.userEmail = email
|
||||
}
|
||||
|
||||
// GetUserEmail returns the authenticated user's email address. Returns
|
||||
// the empty string when the auth path didn't carry an email (e.g.
|
||||
// non-OIDC schemes or legacy JWTs minted before the email claim).
|
||||
func (c *CapturedData) GetUserEmail() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.userEmail
|
||||
}
|
||||
|
||||
// SetUserGroups records the authenticated user's group memberships so
|
||||
// downstream policy-aware middlewares can authorise the request without
|
||||
// an additional management round-trip. The auth middleware populates this
|
||||
// from ValidateSessionResponse / ValidateTunnelPeerResponse and from the
|
||||
// session JWT's groups claim on cookie-bearing requests.
|
||||
func (c *CapturedData) SetUserGroups(groups []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(groups) == 0 {
|
||||
c.userGroups = nil
|
||||
return
|
||||
}
|
||||
c.userGroups = append(c.userGroups[:0], groups...)
|
||||
}
|
||||
|
||||
// GetUserGroups returns a copy of the authenticated user's group
|
||||
// memberships.
|
||||
func (c *CapturedData) GetUserGroups() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if len(c.userGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(c.userGroups))
|
||||
copy(out, c.userGroups)
|
||||
return out
|
||||
}
|
||||
|
||||
// SetUserGroupNames records the human-readable display names for the
|
||||
// user's groups, ordered identically to UserGroups (positional
|
||||
// pairing). Stamped onto upstream requests as X-NetBird-Groups so
|
||||
// downstream services can read names rather than opaque ids.
|
||||
func (c *CapturedData) SetUserGroupNames(names []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(names) == 0 {
|
||||
c.userGroupNames = nil
|
||||
return
|
||||
}
|
||||
c.userGroupNames = append(c.userGroupNames[:0], names...)
|
||||
}
|
||||
|
||||
// GetUserGroupNames returns a copy of the authenticated user's group
|
||||
// display names. Position i pairs with UserGroups[i]. May be shorter
|
||||
// than UserGroups for tokens minted before names were resolvable; the
|
||||
// consumer should fall back to ids for missing positions.
|
||||
func (c *CapturedData) GetUserGroupNames() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if len(c.userGroupNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(c.userGroupNames))
|
||||
copy(out, c.userGroupNames)
|
||||
return out
|
||||
}
|
||||
|
||||
// SetAuthMethod sets the authentication method used.
|
||||
func (c *CapturedData) SetAuthMethod(method string) {
|
||||
c.mu.Lock()
|
||||
|
||||
@@ -86,6 +86,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if pt.RequestTimeout > 0 {
|
||||
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
|
||||
}
|
||||
if pt.DirectUpstream {
|
||||
ctx = roundtrip.WithDirectUpstream(ctx)
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
if pt.PathRewrite == PathRewritePreserve {
|
||||
@@ -142,6 +145,8 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
stampNetBirdIdentity(r)
|
||||
|
||||
clientIP := extractHostIP(r.In.RemoteAddr)
|
||||
|
||||
if isTrustedAddr(clientIP, p.trustedProxies) {
|
||||
@@ -426,3 +431,46 @@ func opErrorContains(err error, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
// headerNetBirdUser carries the authenticated user's display identity
|
||||
// (email when the peer is attached to a user, else peer name) onto
|
||||
// upstream requests. Stripped from inbound requests before stamping
|
||||
// so a client can't spoof identity by setting the header themselves.
|
||||
headerNetBirdUser = "X-NetBird-User"
|
||||
// headerNetBirdGroups carries the user's group display names as a
|
||||
// comma-separated list. Falls back to group IDs at positions where a
|
||||
// name wasn't available at session-mint time.
|
||||
headerNetBirdGroups = "X-NetBird-Groups"
|
||||
)
|
||||
|
||||
// stampNetBirdIdentity injects authenticated identity onto outbound
|
||||
// requests as X-NetBird-User and X-NetBird-Groups. Always strips any
|
||||
// client-sent values first (anti-spoof). Skips when the request didn't
|
||||
// carry CapturedData (early-path errors, internal endpoints).
|
||||
func stampNetBirdIdentity(r *httputil.ProxyRequest) {
|
||||
r.Out.Header.Del(headerNetBirdUser)
|
||||
r.Out.Header.Del(headerNetBirdGroups)
|
||||
|
||||
cd := CapturedDataFromContext(r.In.Context())
|
||||
if cd == nil {
|
||||
return
|
||||
}
|
||||
if email := cd.GetUserEmail(); email != "" {
|
||||
r.Out.Header.Set(headerNetBirdUser, email)
|
||||
}
|
||||
groupIDs := cd.GetUserGroups()
|
||||
if len(groupIDs) == 0 {
|
||||
return
|
||||
}
|
||||
groupNames := cd.GetUserGroupNames()
|
||||
labels := make([]string, len(groupIDs))
|
||||
for i, id := range groupIDs {
|
||||
if i < len(groupNames) && groupNames[i] != "" {
|
||||
labels[i] = groupNames[i]
|
||||
continue
|
||||
}
|
||||
labels[i] = id
|
||||
}
|
||||
r.Out.Header.Set(headerNetBirdGroups, strings.Join(labels, ","))
|
||||
}
|
||||
|
||||
@@ -1067,3 +1067,144 @@ func TestClassifyProxyError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"client-supplied X-NetBird-User must be stripped when no captured identity is present")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"client-supplied X-NetBird-Groups must be stripped when no captured identity is present")
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-eng", "grp-ops"})
|
||||
cd.SetUserGroupNames([]string{"engineering", "operations"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "alice@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
|
||||
"captured email must overwrite any spoofed value")
|
||||
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"group display names must be CSV-joined in positional order")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty covers the
|
||||
// tunnel-peer-without-user case (machine agents, unattached proxy peers).
|
||||
// The proxy must still stamp the peer's groups so downstream services can
|
||||
// authorise, but X-NetBird-User stays unset — only its inbound stripping
|
||||
// must happen.
|
||||
func TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserGroups([]string{"grp-machines"})
|
||||
cd.SetUserGroupNames([]string{"machines"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"X-NetBird-User must remain unset when CapturedData carries no email")
|
||||
assert.Equal(t, "machines", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"groups must still be stamped for peers without a user identity")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty covers the symmetric
|
||||
// case: identity-resolved user without resolved group memberships.
|
||||
func TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("carol@netbird.io")
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "carol@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
|
||||
"email must be stamped even when no groups are captured")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"X-NetBird-Groups must remain unset when CapturedData carries no groups")
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("bob@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
|
||||
// "grp-b" gets an explicit empty-string display name (not just a
|
||||
// shorter slice). Both gap shapes must fall back to the id.
|
||||
cd.SetUserGroupNames([]string{"alpha", "", ""})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "alpha,grp-b,grp-c", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"empty-string and out-of-range name slots must both fall back to the group id")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
|
||||
// that carry CapturedData with no identity fields populated (e.g. the
|
||||
// auth middleware ran but the request didn't authenticate). Both
|
||||
// headers must be cleared and neither stamped.
|
||||
func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"X-NetBird-User must be stripped when CapturedData has no email")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"X-NetBird-Groups must be stripped when CapturedData has no groups")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,10 @@ type PathTarget struct {
|
||||
RequestTimeout time.Duration
|
||||
PathRewrite PathRewriteMode
|
||||
CustomHeaders map[string]string
|
||||
// DirectUpstream selects the stdlib HTTP transport (host network stack)
|
||||
// over the embedded NetBird WireGuard client when forwarding requests
|
||||
// to this target. Default false → embedded client (existing behaviour).
|
||||
DirectUpstream bool
|
||||
}
|
||||
|
||||
// Mapping describes how a domain is routed by the HTTP reverse proxy.
|
||||
|
||||
68
proxy/internal/roundtrip/multi.go
Normal file
68
proxy/internal/roundtrip/multi.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MultiTransport dispatches each request to either the embedded NetBird
|
||||
// http.RoundTripper or a stdlib http.Transport based on a per-request
|
||||
// context flag set by the reverse-proxy rewrite step. When the flag is
|
||||
// absent (the default for every existing target), requests follow the
|
||||
// embedded NetBird path — current behaviour, preserved.
|
||||
//
|
||||
// The stdlib branch is used when a target was configured with
|
||||
// direct_upstream=true. It dials via the host's network stack, which is
|
||||
// what private (`netbird proxy`) deployments and centralised proxies
|
||||
// fronting host-reachable upstreams (public APIs, LAN services,
|
||||
// localhost sidecars) want.
|
||||
type MultiTransport struct {
|
||||
embedded http.RoundTripper
|
||||
direct *http.Transport
|
||||
insecure *http.Transport
|
||||
}
|
||||
|
||||
// NewMultiTransport wires both branches. embedded is the existing NetBird
|
||||
// roundtripper; the direct branches are constructed here with sensible
|
||||
// defaults that mirror Go's stdlib defaults plus a dial-timeout wrapper
|
||||
// honouring the per-request value attached via types.WithDialTimeout.
|
||||
// Pass embedded=nil to disable the WG branch entirely (every request
|
||||
// will route direct, regardless of the context flag).
|
||||
func NewMultiTransport(embedded http.RoundTripper) *MultiTransport {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
direct := &http.Transport{
|
||||
DialContext: dialWithTimeout(dialer.DialContext),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
insecure := direct.Clone()
|
||||
insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in
|
||||
|
||||
return &MultiTransport{
|
||||
embedded: embedded,
|
||||
direct: direct,
|
||||
insecure: insecure,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip dispatches by reading the direct-upstream flag from the request
|
||||
// context. When set, the request is forwarded via the stdlib transport,
|
||||
// honouring the existing per-request skip-TLS-verify flag. Otherwise it
|
||||
// goes through the embedded NetBird roundtripper.
|
||||
func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if DirectUpstreamFromContext(req.Context()) || m.embedded == nil {
|
||||
if skipTLSVerifyFromContext(req.Context()) {
|
||||
return m.insecure.RoundTrip(req)
|
||||
}
|
||||
return m.direct.RoundTrip(req)
|
||||
}
|
||||
return m.embedded.RoundTrip(req)
|
||||
}
|
||||
79
proxy/internal/roundtrip/multi_test.go
Normal file
79
proxy/internal/roundtrip/multi_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubRoundTripper records whether RoundTrip was called and returns a
|
||||
// canned response so tests can assert the dispatch decision without
|
||||
// running a real network.
|
||||
type stubRoundTripper struct {
|
||||
called bool
|
||||
body string
|
||||
}
|
||||
|
||||
func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
s.called = true
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(s.body)),
|
||||
Header: http.Header{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestMultiTransport_DispatchesByContextFlag(t *testing.T) {
|
||||
embedded := &stubRoundTripper{body: "embedded"}
|
||||
mt := NewMultiTransport(embedded)
|
||||
|
||||
t.Run("default routes to embedded", func(t *testing.T) {
|
||||
embedded.called = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "embedded path must not error on stubbed transport")
|
||||
require.NotNil(t, resp)
|
||||
_ = resp.Body.Close()
|
||||
assert.True(t, embedded.called, "request without WithDirectUpstream must hit the embedded transport")
|
||||
})
|
||||
|
||||
t.Run("WithDirectUpstream skips embedded", func(t *testing.T) {
|
||||
embedded.called = false
|
||||
// Hit a server we control to verify the stdlib transport is used.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "direct path must dial via stdlib transport")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "direct", string(body), "stdlib transport must reach the test server")
|
||||
assert.False(t, embedded.called, "WithDirectUpstream must bypass the embedded transport")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiTransport_NilEmbeddedAlwaysDirects(t *testing.T) {
|
||||
mt := NewMultiTransport(nil)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "nil embedded must fall through to direct without panic")
|
||||
_ = resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -76,11 +77,11 @@ type clientEntry struct {
|
||||
services map[ServiceKey]serviceInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// ready is closed once the client has been fully initialized.
|
||||
// Callers that find a pending entry wait on this channel before
|
||||
// accessing the client. A nil initErr means success.
|
||||
ready chan struct{}
|
||||
initErr error
|
||||
// inbound is opaque per-account state owned by the NetBird parent's
|
||||
// ReadyHandler. The roundtrip package never inspects this value; it
|
||||
// only stores it so RemovePeer / StopAll can hand it back to the
|
||||
// matching StopHandler. Nil when no inbound integration is active.
|
||||
inbound any
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
@@ -88,6 +89,19 @@ type clientEntry struct {
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
// IdentityForIP resolves a tunnel IP to the peer identity locally known by
|
||||
// this account's embedded client. Returns (pubKey, fqdn) on success.
|
||||
// ok=false means the IP is not in the account's roster — callers can use
|
||||
// that as a fast deny without round-tripping management. The returned
|
||||
// strings carry only what the embedded peerstore exposes; user identity
|
||||
// (UserID / Email / Groups) still flows through ValidateTunnelPeer.
|
||||
func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if e == nil || e.client == nil || !ip.IsValid() {
|
||||
return "", "", false
|
||||
}
|
||||
return e.client.IdentityForIP(ip)
|
||||
}
|
||||
|
||||
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
||||
// It returns a release function that must always be called, and true on success.
|
||||
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
||||
@@ -117,6 +131,12 @@ type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
|
||||
// standalone proxy where the embedded client never accepts inbound;
|
||||
// set to false on the private/embedded proxy so the engine creates
|
||||
// the ACL manager and applies management's per-policy firewall rules
|
||||
// (which is what gates per-account inbound listeners on the netstack).
|
||||
BlockInbound bool
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
@@ -142,6 +162,14 @@ type NetBird struct {
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
// Ready. The opaque return value is stored on clientEntry and handed
|
||||
// back to stopHandler when the entry is torn down. Nil disables the
|
||||
// hook entirely (default for the standalone proxy).
|
||||
readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any
|
||||
// stopHandler runs when an account's last service is removed (or the
|
||||
// transport is shutting down). Receives whatever readyHandler returned.
|
||||
stopHandler func(accountID types.AccountID, state any)
|
||||
|
||||
// OnAddPeer, when set, is called after AddPeer completes for a new account
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
@@ -167,9 +195,6 @@ type skipTLSVerifyContextKey struct{}
|
||||
// AddPeer registers a service for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple services can share the same client.
|
||||
//
|
||||
// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux
|
||||
// so that concurrent AddPeer calls for different accounts execute in parallel.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
@@ -177,23 +202,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
ready := entry.ready
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// If the entry is still being initialized by another goroutine, wait.
|
||||
if ready != nil {
|
||||
select {
|
||||
case <-ready:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if entry.initErr != nil {
|
||||
return fmt.Errorf("peer initialization failed: %w", entry.initErr)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -210,43 +222,19 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert a placeholder so other goroutines calling AddPeer for the same
|
||||
// account will wait on the ready channel instead of starting a second
|
||||
// client creation.
|
||||
entry = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: si},
|
||||
ready: make(chan struct{}),
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
createStart := time.Now()
|
||||
created, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
if n.OnAddPeer != nil {
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
entry.initErr = err
|
||||
close(entry.ready)
|
||||
|
||||
n.clientsMux.Lock()
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Transfer any services that were registered by concurrent AddPeer calls
|
||||
// while we were creating the client.
|
||||
n.clientsMux.Lock()
|
||||
for k, v := range entry.services {
|
||||
created.services[k] = v
|
||||
}
|
||||
created.ready = nil
|
||||
n.clients[accountID] = created
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
close(entry.ready)
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -254,13 +242,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, created.client)
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client.
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -318,7 +306,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
})
|
||||
@@ -385,8 +373,25 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
readyHandler := n.readyHandler
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if readyHandler != nil {
|
||||
state := readyHandler(ctx, accountID, client)
|
||||
n.clientsMux.Lock()
|
||||
if e, ok := n.clients[accountID]; ok {
|
||||
e.inbound = state
|
||||
} else if state != nil && n.stopHandler != nil {
|
||||
// Account was removed while readyHandler ran; tear down the
|
||||
// resources it just brought up.
|
||||
stop := n.stopHandler
|
||||
n.clientsMux.Unlock()
|
||||
stop(accountID, state)
|
||||
n.clientsMux.Lock()
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
}
|
||||
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
@@ -432,11 +437,15 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -450,6 +459,9 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
@@ -536,8 +548,12 @@ func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
stopHandler := n.stopHandler
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
if entry.inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
entry.transport.CloseIdleConnections()
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
@@ -590,6 +606,19 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// IdentityForIP resolves a tunnel IP to a peer identity local to the given
|
||||
// account. Delegates to clientEntry.IdentityForIP. Returns ok=false when
|
||||
// the account has no client or the IP is not in its peerstore.
|
||||
func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
n.clientsMux.RLock()
|
||||
entry, exists := n.clients[accountID]
|
||||
n.clientsMux.RUnlock()
|
||||
if !exists {
|
||||
return "", "", false
|
||||
}
|
||||
return entry.IdentityForIP(ip)
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
@@ -645,6 +674,18 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
|
||||
}
|
||||
}
|
||||
|
||||
// SetClientLifecycle registers callbacks that run when an embedded
|
||||
// client becomes ready and when its entry is torn down. The opaque value
|
||||
// returned by ready is stored on the entry and handed back to stop on
|
||||
// cleanup. Must be called before AddPeer. A nil pair leaves the
|
||||
// outbound-only behaviour intact.
|
||||
func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.readyHandler = ready
|
||||
n.stopHandler = stop
|
||||
}
|
||||
|
||||
// dialWithTimeout wraps a DialContext function so that any dial timeout
|
||||
// stored in the context (via types.WithDialTimeout) is applied only to
|
||||
// the connection establishment phase, not the full request lifetime.
|
||||
@@ -687,3 +728,22 @@ func skipTLSVerifyFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// directUpstreamContextKey signals that the request should bypass the embedded
|
||||
// NetBird WireGuard client and dial via the host's network stack instead.
|
||||
// Set by the reverse-proxy rewrite step when the matched target carries
|
||||
// PathTarget.DirectUpstream; consumed by MultiTransport.
|
||||
type directUpstreamContextKey struct{}
|
||||
|
||||
// WithDirectUpstream marks the context so MultiTransport routes the request
|
||||
// through its stdlib transport instead of the embedded NetBird roundtripper.
|
||||
func WithDirectUpstream(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, directUpstreamContextKey{}, true)
|
||||
}
|
||||
|
||||
// DirectUpstreamFromContext reports whether the context has been marked to
|
||||
// bypass the embedded NetBird client.
|
||||
func DirectUpstreamFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package roundtrip
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -305,6 +306,36 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
|
||||
// public lookup short-circuits when no client has been registered for
|
||||
// the queried account. The auth middleware uses ok=false as a fast deny.
|
||||
func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
_, _, ok := nb.IdentityForIP("acct-missing", netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "unknown account must yield ok=false")
|
||||
}
|
||||
|
||||
// TestClientEntry_IdentityForIP_NilClientGuard ensures the receiver
|
||||
// methods stay safe when called on partially-initialized state, which
|
||||
// can happen briefly during AddPeer setup or test fixtures.
|
||||
func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) {
|
||||
var e *clientEntry
|
||||
_, _, ok := e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "nil clientEntry must yield ok=false")
|
||||
|
||||
e = &clientEntry{}
|
||||
_, _, ok = e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "clientEntry with nil embed.Client must yield ok=false")
|
||||
}
|
||||
|
||||
// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse covers the input
|
||||
// guard so callers don't have to repeat the check.
|
||||
func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
|
||||
e := &clientEntry{}
|
||||
_, _, ok := e.IdentityForIP(netip.Addr{})
|
||||
assert.False(t, ok, "invalid IP must yield ok=false")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||
|
||||
@@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) {
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(hello)
|
||||
conn := &readerConn{Reader: r}
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
sni, wrapped, _, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(httpReq)
|
||||
conn := &readerConn{Reader: r}
|
||||
_, wrapped, err := PeekClientHello(conn)
|
||||
_, wrapped, _, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -100,28 +100,50 @@ type Router struct {
|
||||
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
|
||||
httpCh chan net.Conn
|
||||
httpListener *chanListener
|
||||
mu sync.RWMutex
|
||||
routes map[SNIHost][]Route
|
||||
fallback *Route
|
||||
draining bool
|
||||
dialResolve DialResolver
|
||||
activeConns sync.WaitGroup
|
||||
activeRelays sync.WaitGroup
|
||||
relaySem chan struct{}
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server.
|
||||
// Set only when NewRouter is called with WithPlainHTTP option (used by
|
||||
// per-account inbound listeners that accept both :80 and :443 traffic).
|
||||
// Nil for the host SNI router and for port routers.
|
||||
httpPlainCh chan net.Conn
|
||||
httpPlainListener *chanListener
|
||||
mu sync.RWMutex
|
||||
routes map[SNIHost][]Route
|
||||
fallback *Route
|
||||
draining bool
|
||||
dialResolve DialResolver
|
||||
activeConns sync.WaitGroup
|
||||
activeRelays sync.WaitGroup
|
||||
relaySem chan struct{}
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
||||
// service derive from its context; canceling it kills them immediately.
|
||||
svcCtxs map[types.ServiceID]context.Context
|
||||
svcCancels map[types.ServiceID]context.CancelFunc
|
||||
}
|
||||
|
||||
// RouterOption customises Router construction.
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithPlainHTTP enables a parallel plain-HTTP channel on the router. When
|
||||
// set, connections whose first byte is not a TLS handshake are forwarded
|
||||
// to the plain channel returned by HTTPListenerPlain instead of the TLS
|
||||
// channel. Used by per-account inbound listeners that share both :80 and
|
||||
// :443 traffic on the same router.
|
||||
func WithPlainHTTP(addr net.Addr) RouterOption {
|
||||
return func(r *Router) {
|
||||
ch := make(chan net.Conn, httpChannelBuffer)
|
||||
r.httpPlainCh = ch
|
||||
r.httpPlainListener = newChanListener(ch, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// NewRouter creates a new SNI-based connection router.
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router {
|
||||
httpCh := make(chan net.Conn, httpChannelBuffer)
|
||||
return &Router{
|
||||
r := &Router{
|
||||
logger: logger,
|
||||
httpCh: httpCh,
|
||||
httpListener: newChanListener(httpCh, addr),
|
||||
@@ -131,6 +153,10 @@ func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Rou
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// NewPortRouter creates a Router for a dedicated port without an HTTP
|
||||
@@ -153,6 +179,16 @@ func (r *Router) HTTPListener() net.Listener {
|
||||
return r.httpListener
|
||||
}
|
||||
|
||||
// HTTPListenerPlain returns a net.Listener yielding non-TLS connections
|
||||
// for use with a parallel plain http.Server. Returns nil when the router
|
||||
// was not constructed with WithPlainHTTP.
|
||||
func (r *Router) HTTPListenerPlain() net.Listener {
|
||||
if r.httpPlainListener == nil {
|
||||
return nil
|
||||
}
|
||||
return r.httpPlainListener
|
||||
}
|
||||
|
||||
// AddRoute registers an SNI route. Multiple routes for the same host are
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
@@ -254,6 +290,9 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
if r.httpListener != nil {
|
||||
r.httpListener.Close()
|
||||
}
|
||||
if r.httpPlainListener != nil {
|
||||
r.httpPlainListener.Close()
|
||||
}
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
@@ -270,6 +309,7 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
r.logger.Debugf("SNI router accept: %v", err)
|
||||
continue
|
||||
}
|
||||
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
r.activeConns.Add(1)
|
||||
go func() {
|
||||
defer r.activeConns.Done()
|
||||
@@ -278,13 +318,24 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
// HandleConn lets external accept loops feed a connection through the
|
||||
// router's peek-and-dispatch logic. Use this when the same router serves
|
||||
// a secondary listener (for example, a per-account inbound :80 socket
|
||||
// alongside its :443 socket).
|
||||
func (r *Router) HandleConn(ctx context.Context, conn net.Conn) {
|
||||
r.activeConns.Add(1)
|
||||
defer r.activeConns.Done()
|
||||
r.handleConn(ctx, conn)
|
||||
}
|
||||
|
||||
// handleConn peeks at the TLS ClientHello and routes the connection.
|
||||
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
|
||||
// fallback port), skip the TLS peek entirely to avoid read errors on
|
||||
// non-TLS connections and reduce latency.
|
||||
if r.isFallbackOnly() {
|
||||
r.handleUnmatched(ctx, conn)
|
||||
r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr())
|
||||
r.handleUnmatched(ctx, conn, false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -294,11 +345,11 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
r.logger.Debugf("SNI peek: %v", err)
|
||||
r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err)
|
||||
if wrapped != nil {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
r.handleUnmatched(ctx, wrapped, isTLS)
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -313,13 +364,20 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
|
||||
host := SNIHost(strings.ToLower(sni))
|
||||
route, ok := r.lookupRoute(host)
|
||||
r.logger.WithFields(log.Fields{
|
||||
"remote": wrapped.RemoteAddr().String(),
|
||||
"sni": string(host),
|
||||
"match": ok,
|
||||
"tls": isTLS,
|
||||
}).Debug("SNI route lookup")
|
||||
if !ok {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
r.handleUnmatched(ctx, wrapped, isTLS)
|
||||
return
|
||||
}
|
||||
|
||||
if route.Type == RouteHTTP {
|
||||
r.sendToHTTP(wrapped)
|
||||
r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID)
|
||||
r.sendToHTTP(wrapped, isTLS)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -344,15 +402,17 @@ func (r *Router) isFallbackOnly() bool {
|
||||
}
|
||||
|
||||
// handleUnmatched routes a connection that didn't match any SNI route.
|
||||
// This includes ECH/ESNI connections where the cleartext SNI is empty.
|
||||
// This includes ECH/ESNI connections where the cleartext SNI is empty,
|
||||
// and plain (non-TLS) HTTP connections when isTLS is false.
|
||||
// It tries the fallback relay first, then the HTTP channel, and closes
|
||||
// the connection if neither is available.
|
||||
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) {
|
||||
r.mu.RLock()
|
||||
fb := r.fallback
|
||||
r.mu.RUnlock()
|
||||
|
||||
if fb != nil {
|
||||
r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target)
|
||||
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
@@ -364,7 +424,8 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
return
|
||||
}
|
||||
r.sendToHTTP(conn)
|
||||
r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr())
|
||||
r.sendToHTTP(conn, isTLS)
|
||||
}
|
||||
|
||||
// lookupRoute returns the highest-priority route for the given SNI host.
|
||||
@@ -386,10 +447,20 @@ func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
|
||||
}
|
||||
|
||||
// sendToHTTP feeds the connection to the HTTP handler via the channel.
|
||||
// If no HTTP channel is configured (port router), the router is
|
||||
// draining, or the channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
if r.httpCh == nil {
|
||||
// When isTLS is false and a plain channel is configured the connection
|
||||
// is forwarded to the plain channel; otherwise it lands on the TLS
|
||||
// channel. If no usable channel exists, the router is draining, or the
|
||||
// channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) {
|
||||
ch := r.httpCh
|
||||
chanName := "HTTP"
|
||||
if !isTLS && r.httpPlainCh != nil {
|
||||
ch = r.httpPlainCh
|
||||
chanName = "HTTP-plain"
|
||||
}
|
||||
|
||||
if ch == nil {
|
||||
r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -399,14 +470,15 @@ func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
r.mu.RUnlock()
|
||||
|
||||
if draining {
|
||||
r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case r.httpCh <- conn:
|
||||
case ch <- conn:
|
||||
default:
|
||||
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
|
||||
r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1739,3 +1739,97 @@ func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
|
||||
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
|
||||
}
|
||||
|
||||
// TestRouter_PlainHTTP_RoutesToPlainChannel verifies that a plain (non-TLS)
|
||||
// connection lands on the plain HTTP channel when the router was built
|
||||
// with WithPlainHTTP, leaving the TLS channel untouched.
|
||||
func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
|
||||
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
|
||||
router.AddRoute("example.com", Route{Type: RouteHTTP})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "test listener bind must succeed")
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_ = router.Serve(ctx, ln)
|
||||
}()
|
||||
|
||||
// Plain HTTP request (no TLS handshake byte).
|
||||
go func() {
|
||||
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||
}()
|
||||
|
||||
plainListener := router.HTTPListenerPlain()
|
||||
require.NotNil(t, plainListener, "plain listener must be exposed when WithPlainHTTP is set")
|
||||
|
||||
acceptDone := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := plainListener.Accept()
|
||||
if err == nil {
|
||||
acceptDone <- conn
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case conn := <-acceptDone:
|
||||
require.NotNil(t, conn)
|
||||
_ = conn.Close()
|
||||
case <-router.HTTPListener().(*chanListener).ch:
|
||||
t.Fatal("plain HTTP request leaked into TLS channel")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("plain HTTP connection never reached plain channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled verifies that the
|
||||
// presence of a plain channel does not divert TLS traffic — TLS still
|
||||
// goes to the TLS channel as before.
|
||||
func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
|
||||
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
|
||||
router.AddRoute("example.com", Route{Type: RouteHTTP})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "test listener bind must succeed")
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = router.Serve(ctx, ln) }()
|
||||
|
||||
// Send a TLS ClientHello.
|
||||
go func() {
|
||||
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tlsConn := tls.Client(conn, &tls.Config{
|
||||
ServerName: "example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
_ = tlsConn.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case conn := <-router.httpCh:
|
||||
require.NotNil(t, conn, "TLS conn should land on the TLS channel")
|
||||
_ = conn.Close()
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("TLS conn never reached the TLS channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,26 +30,30 @@ const (
|
||||
// bytes transparently. If the data is not a valid TLS ClientHello or
|
||||
// contains no SNI extension, sni is empty and err is nil.
|
||||
//
|
||||
// isTLS reports whether the first byte indicated a TLS handshake record.
|
||||
// Callers can use this to distinguish plain (non-TLS) traffic from a TLS
|
||||
// stream that simply lacked an SNI extension or used ECH.
|
||||
//
|
||||
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
|
||||
// real server name is encrypted inside the encrypted_client_hello
|
||||
// extension. This parser only reads the cleartext server_name extension
|
||||
// (type 0x0000), so ECH connections return sni="" and are routed through
|
||||
// the fallback path (or HTTP channel), which is the correct behavior
|
||||
// for a transparent proxy that does not terminate TLS.
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, isTLS bool, err error) {
|
||||
// Read the 5-byte TLS record header into a small stack-friendly buffer.
|
||||
var header [tlsRecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return "", nil, fmt.Errorf("read TLS record header: %w", err)
|
||||
return "", nil, false, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if header[0] != contentTypeHandshake {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
return "", newPeekedConn(conn, header[:]), false, nil
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxClientHelloLen {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
return "", newPeekedConn(conn, header[:]), true, nil
|
||||
}
|
||||
|
||||
// Single allocation for header + payload. The peekedConn takes
|
||||
@@ -59,11 +63,11 @@ func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
|
||||
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
|
||||
if err != nil {
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), true, fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
}
|
||||
|
||||
sni = extractSNI(buf[tlsRecordHeaderLen:])
|
||||
return sni, newPeekedConn(conn, buf), nil
|
||||
return sni, newPeekedConn(conn, buf), true, nil
|
||||
}
|
||||
|
||||
// extractSNI parses a TLS handshake payload to find the SNI extension.
|
||||
|
||||
@@ -29,10 +29,11 @@ func TestPeekClientHello_ValidSNI(t *testing.T) {
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
|
||||
assert.NotNil(t, wrapped, "wrapped connection should not be nil")
|
||||
assert.True(t, isTLS, "TLS ClientHello should be flagged as TLS")
|
||||
|
||||
// Verify the wrapped connection replays the peeked bytes.
|
||||
// Read the first 5 bytes (TLS record header) to confirm replay.
|
||||
@@ -83,10 +84,11 @@ func TestPeekClientHello_MultipleSNIs(t *testing.T) {
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedSNI, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.True(t, isTLS, "TLS handshake should be flagged as TLS")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,10 +104,11 @@ func TestPeekClientHello_NonTLSData(t *testing.T) {
|
||||
_, _ = clientConn.Write(httpData)
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni, "should return empty SNI for non-TLS data")
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.False(t, isTLS, "plain HTTP data should not be flagged as TLS")
|
||||
|
||||
// Verify the wrapped connection still provides the original data.
|
||||
buf := make([]byte, len(httpData))
|
||||
@@ -124,7 +127,7 @@ func TestPeekClientHello_TruncatedHeader(t *testing.T) {
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
_, _, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated header")
|
||||
}
|
||||
|
||||
@@ -140,7 +143,7 @@ func TestPeekClientHello_TruncatedPayload(t *testing.T) {
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
_, _, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated payload")
|
||||
}
|
||||
|
||||
@@ -154,10 +157,11 @@ func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
|
||||
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.True(t, isTLS, "zero-length record should still be a TLS handshake byte")
|
||||
}
|
||||
|
||||
func TestExtractSNI_InvalidPayload(t *testing.T) {
|
||||
|
||||
160
proxy/lifecycle.go
Normal file
160
proxy/lifecycle.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
)
|
||||
|
||||
// Config bundles every knob the proxy reads at construction time. It mirrors
|
||||
// the public fields on Server so library callers don't have to learn the
|
||||
// internal struct layout. Zero values mean "feature off" or "fall back to the
|
||||
// internal default" depending on the field — see the per-field doc.
|
||||
//
|
||||
// The standalone binary continues to populate Server fields directly, so
|
||||
// adding fields here must not change the zero-value behaviour of Server.
|
||||
type Config struct {
|
||||
// ListenAddr is the TCP address the main listener binds. Required.
|
||||
ListenAddr string
|
||||
// ID identifies this proxy instance to management. Empty value lets
|
||||
// New generate a timestamped default.
|
||||
ID string
|
||||
// Logger is the logrus logger used everywhere. Empty value falls back
|
||||
// to log.StandardLogger().
|
||||
Logger *log.Logger
|
||||
// Version is the build version string reported to management. Empty
|
||||
// becomes "dev".
|
||||
Version string
|
||||
// ProxyURL is the public address operators use to reach this proxy.
|
||||
ProxyURL string
|
||||
// ManagementAddress is the gRPC URL of the management server.
|
||||
ManagementAddress string
|
||||
// ProxyToken authenticates this proxy with the management server.
|
||||
ProxyToken string
|
||||
|
||||
// CertificateDirectory is the directory holding TLS certificate
|
||||
// material (static or ACME-provisioned).
|
||||
CertificateDirectory string
|
||||
// CertificateFile is the certificate filename within
|
||||
// CertificateDirectory.
|
||||
CertificateFile string
|
||||
// CertificateKeyFile is the private key filename within
|
||||
// CertificateDirectory.
|
||||
CertificateKeyFile string
|
||||
// GenerateACMECertificates toggles ACME certificate provisioning.
|
||||
GenerateACMECertificates bool
|
||||
// ACMEChallengeAddress is the listen address for HTTP-01 challenges.
|
||||
ACMEChallengeAddress string
|
||||
// ACMEDirectory is the ACME directory URL (Let's Encrypt by default).
|
||||
ACMEDirectory string
|
||||
// ACMEEABKID is the External Account Binding Key ID for CAs that
|
||||
// require EAB (e.g. ZeroSSL).
|
||||
ACMEEABKID string
|
||||
// ACMEEABHMACKey is the External Account Binding HMAC key for CAs
|
||||
// that require EAB.
|
||||
ACMEEABHMACKey string
|
||||
// ACMEChallengeType is the ACME challenge type ("tls-alpn-01" or
|
||||
// "http-01"). Empty defaults to "tls-alpn-01".
|
||||
ACMEChallengeType string
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated
|
||||
// across replicas.
|
||||
CertLockMethod acme.CertLockMethod
|
||||
// WildcardCertDir is an optional directory containing static wildcard
|
||||
// certificates that override ACME for matching domains.
|
||||
WildcardCertDir string
|
||||
|
||||
// DebugEndpointEnabled toggles the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
// DebugEndpointAddress is the bind address for the debug endpoint.
|
||||
DebugEndpointAddress string
|
||||
// HealthAddr is the bind address for the health probe and metrics
|
||||
// surface. Empty disables the health probe entirely (library callers
|
||||
// can attach their own).
|
||||
HealthAddr string
|
||||
|
||||
// ForwardedProto overrides the X-Forwarded-Proto value sent to
|
||||
// backends. Valid values: "auto", "http", "https".
|
||||
ForwardedProto string
|
||||
// TrustedProxies is a list of IP prefixes for trusted upstream
|
||||
// proxies that may set forwarding headers.
|
||||
TrustedProxies []netip.Prefix
|
||||
// WireguardPort is the UDP port for the embedded NetBird tunnel.
|
||||
// Zero asks the OS for a random port.
|
||||
WireguardPort uint16
|
||||
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
|
||||
ProxyProtocol bool
|
||||
// PreSharedKey is the WireGuard pre-shared key used between the
|
||||
// proxy's embedded clients and peers.
|
||||
PreSharedKey string
|
||||
|
||||
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
||||
// ports for TCP/UDP/TLS services.
|
||||
SupportsCustomPorts bool
|
||||
// RequireSubdomain forces accounts to use a subdomain in front of
|
||||
// the proxy's cluster domain.
|
||||
RequireSubdomain bool
|
||||
// Private flags this proxy as embedded in a netbird client and
|
||||
// serving exclusively over the WireGuard tunnel. Also enables
|
||||
// per-account inbound listeners on each embedded client's netstack.
|
||||
Private bool
|
||||
|
||||
// MaxDialTimeout caps the per-service backend dial timeout.
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
// CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec.
|
||||
CrowdSecAPIURL string
|
||||
// CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables
|
||||
// CrowdSec.
|
||||
CrowdSecAPIKey string
|
||||
}
|
||||
|
||||
// New builds a Server from cfg without performing any I/O. No goroutines
|
||||
// are spawned, no network connections are dialed, and no listeners are
|
||||
// bound — call Start to bring the proxy up. Returning a fully-formed
|
||||
// Server keeps the standalone code path (which still constructs Server
|
||||
// directly) byte-for-byte equivalent.
|
||||
func New(cfg Config) *Server {
|
||||
return &Server{
|
||||
ListenAddr: cfg.ListenAddr,
|
||||
ID: cfg.ID,
|
||||
Logger: cfg.Logger,
|
||||
Version: cfg.Version,
|
||||
ProxyURL: cfg.ProxyURL,
|
||||
ManagementAddress: cfg.ManagementAddress,
|
||||
ProxyToken: cfg.ProxyToken,
|
||||
CertificateDirectory: cfg.CertificateDirectory,
|
||||
CertificateFile: cfg.CertificateFile,
|
||||
CertificateKeyFile: cfg.CertificateKeyFile,
|
||||
GenerateACMECertificates: cfg.GenerateACMECertificates,
|
||||
ACMEChallengeAddress: cfg.ACMEChallengeAddress,
|
||||
ACMEDirectory: cfg.ACMEDirectory,
|
||||
ACMEEABKID: cfg.ACMEEABKID,
|
||||
ACMEEABHMACKey: cfg.ACMEEABHMACKey,
|
||||
ACMEChallengeType: cfg.ACMEChallengeType,
|
||||
CertLockMethod: cfg.CertLockMethod,
|
||||
WildcardCertDir: cfg.WildcardCertDir,
|
||||
DebugEndpointEnabled: cfg.DebugEndpointEnabled,
|
||||
DebugEndpointAddress: cfg.DebugEndpointAddress,
|
||||
HealthAddress: cfg.HealthAddr,
|
||||
ForwardedProto: cfg.ForwardedProto,
|
||||
TrustedProxies: cfg.TrustedProxies,
|
||||
WireguardPort: cfg.WireguardPort,
|
||||
ProxyProtocol: cfg.ProxyProtocol,
|
||||
PreSharedKey: cfg.PreSharedKey,
|
||||
SupportsCustomPorts: cfg.SupportsCustomPorts,
|
||||
RequireSubdomain: cfg.RequireSubdomain,
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
}
|
||||
}
|
||||
134
proxy/lifecycle_test.go
Normal file
134
proxy/lifecycle_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// quietLifecycleLogger keeps lifecycle tests from spamming the test output.
|
||||
func quietLifecycleLogger() *log.Logger {
|
||||
l := log.New()
|
||||
l.SetOutput(io.Discard)
|
||||
l.SetLevel(log.PanicLevel)
|
||||
return l
|
||||
}
|
||||
|
||||
func TestNewIsPureConstructor(t *testing.T) {
|
||||
cfg := Config{
|
||||
ListenAddr: ":0",
|
||||
ID: "test-id",
|
||||
Logger: quietLifecycleLogger(),
|
||||
Version: "test",
|
||||
ManagementAddress: "https://example.invalid",
|
||||
HealthAddr: "",
|
||||
ForwardedProto: "auto",
|
||||
}
|
||||
|
||||
srv := New(cfg)
|
||||
require.NotNil(t, srv, "New must return a non-nil Server")
|
||||
|
||||
assert.Equal(t, ":0", srv.ListenAddr, "ListenAddr should round-trip")
|
||||
assert.Equal(t, "test-id", srv.ID, "ID should round-trip")
|
||||
assert.Equal(t, "test", srv.Version, "Version should round-trip")
|
||||
assert.Equal(t, "https://example.invalid", srv.ManagementAddress, "ManagementAddress should round-trip")
|
||||
assert.Equal(t, "auto", srv.ForwardedProto, "ForwardedProto should round-trip")
|
||||
|
||||
// Pure constructor: no goroutines, no listener bind, no management dial.
|
||||
assert.False(t, srv.started, "Server must be marked unstarted before Start")
|
||||
assert.Nil(t, srv.mgmtClient, "mgmt client must not be created in New")
|
||||
assert.Nil(t, srv.netbird, "netbird client must not be created in New")
|
||||
assert.Nil(t, srv.https, "https server must not be created in New")
|
||||
assert.Nil(t, srv.healthServer, "health server must not be created in New")
|
||||
assert.Nil(t, srv.runCancel, "runCancel must be nil before Start")
|
||||
assert.Nil(t, srv.runErrCh, "runErrCh must be nil before Start")
|
||||
}
|
||||
|
||||
func TestStopBeforeStartIsNoOp(t *testing.T) {
|
||||
srv := New(Config{Logger: quietLifecycleLogger()})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.Stop(ctx)
|
||||
assert.NoError(t, err, "Stop on an unstarted server must succeed without error")
|
||||
|
||||
err = srv.Stop(ctx)
|
||||
assert.NoError(t, err, "Stop must remain idempotent across repeated calls")
|
||||
}
|
||||
|
||||
func TestStartFailsWithoutManagement(t *testing.T) {
|
||||
srv := New(Config{
|
||||
Logger: quietLifecycleLogger(),
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
ManagementAddress: "://broken-url",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.Start(ctx)
|
||||
require.Error(t, err, "Start must surface management dial failures")
|
||||
|
||||
assert.True(t, srv.started, "started flag is set before any dial attempt so a second Start fails fast")
|
||||
|
||||
err = srv.Start(ctx)
|
||||
require.Error(t, err, "second Start must reject")
|
||||
assert.Contains(t, err.Error(), "already started", "error must explain why the call was rejected")
|
||||
}
|
||||
|
||||
func TestStopIsIdempotent(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: quietLifecycleLogger(),
|
||||
started: true,
|
||||
runErrCh: make(chan struct{}),
|
||||
runCancel: func() {},
|
||||
}
|
||||
srv.recordRunErr(errors.New("synthetic"))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.Stop(ctx)
|
||||
require.Error(t, err, "Stop must surface the recorded background error")
|
||||
assert.Contains(t, err.Error(), "synthetic", "error must round-trip recordRunErr's value")
|
||||
|
||||
err = srv.Stop(ctx)
|
||||
require.Error(t, err, "second Stop must still report the same error")
|
||||
assert.Contains(t, err.Error(), "synthetic", "idempotent Stop must return the cached error")
|
||||
}
|
||||
|
||||
func TestRecordRunErrPreservesFirstFailure(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: quietLifecycleLogger(),
|
||||
runErrCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
srv.recordRunErr(errors.New("first"))
|
||||
srv.recordRunErr(errors.New("second"))
|
||||
|
||||
require.Error(t, srv.runErr, "first failure must be retained")
|
||||
assert.Contains(t, srv.runErr.Error(), "first", "second call must not overwrite the cached error")
|
||||
|
||||
select {
|
||||
case <-srv.runErrCh:
|
||||
default:
|
||||
t.Fatal("recordRunErr must close runErrCh so waitAndStop unblocks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) {
|
||||
srv := New(Config{Logger: quietLifecycleLogger()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := srv.Stop(ctx)
|
||||
assert.NoError(t, err, "Stop on an unstarted server should not block on the cancelled ctx")
|
||||
}
|
||||
@@ -239,6 +239,10 @@ func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
@@ -565,6 +569,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
proxytypes.AccountID(mapping.GetAccountId()),
|
||||
proxytypes.ServiceID(mapping.GetId()),
|
||||
nil,
|
||||
mapping.GetPrivate(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
600
proxy/server.go
600
proxy/server.go
@@ -37,6 +37,7 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
@@ -114,9 +115,28 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
inbound *inboundManager
|
||||
|
||||
// Lifecycle state — populated by Start, consumed by Stop. The fields
|
||||
// stay zero on a fresh Server until Start runs so direct struct
|
||||
// construction (`&Server{...}`) used by tests still works.
|
||||
runCancel context.CancelFunc
|
||||
mgmtConn *grpc.ClientConn
|
||||
runErr error
|
||||
runErrCh chan struct{}
|
||||
startMu sync.Mutex
|
||||
started bool
|
||||
stopOnce sync.Once
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
// ListenAddr is the address the main TCP listener binds. Populated by
|
||||
// New from Config or by ListenAndServe from its addr argument.
|
||||
ListenAddr string
|
||||
ID string
|
||||
Logger *log.Logger
|
||||
Version string
|
||||
@@ -177,6 +197,14 @@ type Server struct {
|
||||
// in front of this proxy's cluster domain. When true, accounts cannot
|
||||
// create services on the bare cluster domain.
|
||||
RequireSubdomain bool
|
||||
// Private flags this proxy as embedded in a netbird client and serving
|
||||
// exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported
|
||||
// upstream as a capability so dashboards can distinguish per-peer
|
||||
// clusters from centralised ones, and turns on per-account inbound
|
||||
// listeners on each embedded client's netstack: every account that
|
||||
// registers a service exposes :80 + :443 inside its own WG tunnel,
|
||||
// scoped to that account's services only.
|
||||
Private bool
|
||||
// MaxDialTimeout caps the per-service backend dial timeout.
|
||||
// When the API sends a timeout, it is clamped to this value.
|
||||
// When the API sends no timeout, this value is used as the default.
|
||||
@@ -222,12 +250,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se
|
||||
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
|
||||
}
|
||||
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
||||
req := &proto.SendStatusUpdateRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Status: status,
|
||||
CertificateIssued: false,
|
||||
})
|
||||
}
|
||||
if connected {
|
||||
req.InboundListener = s.inboundListenerProto(accountID)
|
||||
}
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -238,56 +270,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
|
||||
AccountId: string(accountID),
|
||||
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
|
||||
CertificateIssued: true,
|
||||
InboundListener: s.inboundListenerProto(accountID),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.initDefaults()
|
||||
s.routerReady = make(chan struct{})
|
||||
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
|
||||
s.portRouters = make(map[uint16]*portRouter)
|
||||
s.svcPorts = make(map[types.ServiceID][]uint16)
|
||||
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
|
||||
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create prometheus exporter: %w", err)
|
||||
// inboundListenerProto resolves the per-account inbound listener state for
|
||||
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
|
||||
// or the account has no live listener so management treats the field as
|
||||
// absent.
|
||||
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
|
||||
if s.inbound == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
pkg := reflect.TypeOf(Server{}).PkgPath()
|
||||
meter := provider.Meter(pkg)
|
||||
|
||||
s.meter, err = proxymetrics.New(ctx, meter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create metrics: %w", err)
|
||||
info, ok := s.inbound.ListenerInfo(accountID)
|
||||
if !ok || info.TunnelIP == "" {
|
||||
return nil
|
||||
}
|
||||
return &proto.ProxyInboundListener{
|
||||
TunnelIp: info.TunnelIP,
|
||||
HttpsPort: uint32(info.HTTPSPort),
|
||||
HttpPort: uint32(info.HTTPPort),
|
||||
}
|
||||
}
|
||||
|
||||
mgmtConn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
// ListenAndServe is the standalone entrypoint. It binds the listener, runs
|
||||
// the proxy until ctx is cancelled or a background goroutine fails, then
|
||||
// drains and stops. Library callers should prefer New + Start + Stop and
|
||||
// own their own shutdown signalling.
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
|
||||
s.ListenAddr = addr
|
||||
if err := s.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
return s.waitAndStop(ctx)
|
||||
}
|
||||
|
||||
// Start brings the proxy up: dials management, configures TLS, binds the
|
||||
// main listener, and spawns the SNI router and HTTPS server goroutines. It
|
||||
// returns once the listener is bound; background errors are surfaced
|
||||
// through Stop's return value. Start is not safe to call twice.
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
s.startMu.Lock()
|
||||
if s.started {
|
||||
s.startMu.Unlock()
|
||||
return errors.New("proxy already started")
|
||||
}
|
||||
s.started = true
|
||||
s.startMu.Unlock()
|
||||
|
||||
s.initLifecycleState()
|
||||
if err := s.initMetrics(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.initManagementClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
defer runCancel()
|
||||
s.runCancel = runCancel
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
s.initNetBirdClient()
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
||||
@@ -300,34 +344,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
s.initReverseProxy()
|
||||
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
if err := s.initGeoLookup(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var startupOK bool
|
||||
startupOK := false
|
||||
defer func() {
|
||||
if startupOK {
|
||||
return
|
||||
}
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", err)
|
||||
if closeErr := s.geoRaw.Close(); closeErr != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
@@ -336,35 +371,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = s.accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
handler := s.buildHandlerChain()
|
||||
s.initPrivateInbound(handler, tlsConfig)
|
||||
|
||||
// Start a raw TCP listener; the SNI router peeks at ClientHello
|
||||
// and routes to either the HTTP handler or a TCP relay.
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
ln, err := s.bindMainListener(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
return err
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
|
||||
|
||||
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
|
||||
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
|
||||
s.mainRouter.SetObserver(s.meter)
|
||||
s.mainRouter.SetAccessLogger(s.accessLog)
|
||||
close(s.routerReady)
|
||||
|
||||
// The HTTP server uses the chanListener fed by the SNI router.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Addr: s.ListenAddr,
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadHeaderTimeout: httpReadHeaderTimeout,
|
||||
@@ -374,35 +395,201 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
|
||||
startupOK = true
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
|
||||
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
|
||||
if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
s.recordRunErr(fmt.Errorf("https server: %w", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
routerErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting SNI router on %s", addr)
|
||||
routerErr <- s.mainRouter.Serve(runCtx, ln)
|
||||
s.Logger.Debugf("starting SNI router on %s", s.ListenAddr)
|
||||
if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil {
|
||||
s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop drains in-flight connections, shuts down all background services,
|
||||
// and releases resources. Idempotent; calling it before Start is a no-op.
|
||||
// Returns the first fatal error reported by a background goroutine, if
|
||||
// any. The provided ctx bounds the total wait time; once it is cancelled
|
||||
// Stop returns even if drain is still in flight.
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.stopOnce.Do(func() {
|
||||
s.startMu.Lock()
|
||||
started := s.started
|
||||
s.startMu.Unlock()
|
||||
if !started {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.gracefulShutdown()
|
||||
if s.runCancel != nil {
|
||||
s.runCancel()
|
||||
}
|
||||
if s.mgmtConn != nil {
|
||||
if err := s.mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err())
|
||||
}
|
||||
})
|
||||
|
||||
s.startMu.Lock()
|
||||
defer s.startMu.Unlock()
|
||||
return s.runErr
|
||||
}
|
||||
|
||||
// waitAndStop blocks until ctx is cancelled or a background goroutine
|
||||
// reports a fatal error, then drains and stops. Used by ListenAndServe.
|
||||
func (s *Server) waitAndStop(ctx context.Context) error {
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case err := <-routerErr:
|
||||
s.shutdownServices()
|
||||
if err != nil {
|
||||
return fmt.Errorf("SNI router: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
case <-s.runErrCh:
|
||||
}
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout)
|
||||
defer cancel()
|
||||
return s.Stop(stopCtx)
|
||||
}
|
||||
|
||||
// recordRunErr stores the first fatal background error and signals
|
||||
// waitAndStop. Subsequent errors are logged at debug level so the first
|
||||
// cause is preserved.
|
||||
func (s *Server) recordRunErr(err error) {
|
||||
s.startMu.Lock()
|
||||
defer s.startMu.Unlock()
|
||||
if s.runErr != nil {
|
||||
s.Logger.Debugf("background error after first failure: %v", err)
|
||||
return
|
||||
}
|
||||
s.runErr = err
|
||||
if s.runErrCh != nil {
|
||||
close(s.runErrCh)
|
||||
}
|
||||
}
|
||||
|
||||
// initLifecycleState seeds the maps and channels Start needs to wire up
|
||||
// background goroutines. Called once at the top of Start.
|
||||
func (s *Server) initLifecycleState() {
|
||||
s.initDefaults()
|
||||
s.routerReady = make(chan struct{})
|
||||
s.runErrCh = make(chan struct{})
|
||||
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
|
||||
s.portRouters = make(map[uint16]*portRouter)
|
||||
s.svcPorts = make(map[types.ServiceID][]uint16)
|
||||
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
|
||||
}
|
||||
|
||||
// initMetrics builds the prometheus exporter and meter bundle.
|
||||
func (s *Server) initMetrics(ctx context.Context) error {
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create prometheus exporter: %w", err)
|
||||
}
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
pkg := reflect.TypeOf(Server{}).PkgPath()
|
||||
meter := provider.Meter(pkg)
|
||||
s.meter, err = proxymetrics.New(ctx, meter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create metrics: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initManagementClient dials management and stashes the connection so
|
||||
// Stop can close it deterministically.
|
||||
func (s *Server) initManagementClient() error {
|
||||
conn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mgmtConn = conn
|
||||
s.mgmtClient = proto.NewProxyServiceClient(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initNetBirdClient builds the multi-tenant embedded NetBird client used
|
||||
// for outbound RoundTripping and (when --private-inbound is on) per-account
|
||||
// inbound listeners.
|
||||
func (s *Server) initNetBirdClient() {
|
||||
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
// On --private the embedded client serves per-account inbound
|
||||
// listeners and must apply management's ACL: keep BlockInbound off
|
||||
// so the engine creates the ACL manager. On the standalone proxy
|
||||
// the embedded client never accepts inbound, so block.
|
||||
BlockInbound: !s.Private,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
|
||||
}
|
||||
|
||||
// initReverseProxy builds the meter-instrumented reverse proxy. MultiTransport
|
||||
// routes targets opted into direct_upstream through the host's network stack
|
||||
// (stdlib transport); everything else falls through to the embedded NetBird
|
||||
// client. The split is needed so direct_upstream targets resolve DNS via the
|
||||
// proxy host's resolver instead of the tunnel's DNS.
|
||||
func (s *Server) initReverseProxy() {
|
||||
upstreamRT := roundtrip.NewMultiTransport(s.netbird)
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
}
|
||||
|
||||
// initGeoLookup configures the GeoLite2 lookup used for country-based
|
||||
// access restrictions and access-log enrichment.
|
||||
func (s *Server) initGeoLookup() error {
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHandlerChain wires the request middlewares from inside out.
|
||||
func (s *Server) buildHandlerChain() http.Handler {
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = s.accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
return s.hijackTracker.Middleware(handler)
|
||||
}
|
||||
|
||||
// bindMainListener binds the main TCP listener and wraps it with PROXY
|
||||
// protocol when configured.
|
||||
func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) {
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", s.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"requested_addr": s.ListenAddr,
|
||||
"bound_addr": ln.Addr().String(),
|
||||
"private": s.Private,
|
||||
"proxy_protocol": s.ProxyProtocol,
|
||||
}).Info("proxy main listener bound")
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
@@ -434,6 +621,9 @@ func (s *Server) startDebugEndpoint() {
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
if s.inbound != nil {
|
||||
debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound})
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
@@ -447,16 +637,18 @@ func (s *Server) startDebugEndpoint() {
|
||||
}()
|
||||
}
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
// startHealthServer launches the health probe and metrics server. Empty
|
||||
// HealthAddress disables the probe entirely (intended for library callers
|
||||
// that want to manage their own health surface).
|
||||
func (s *Server) startHealthServer() error {
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
if s.HealthAddress == "" {
|
||||
s.Logger.Debug("health probe disabled (empty HealthAddress)")
|
||||
return nil
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
||||
healthListener, err := net.Listen("tcp", s.HealthAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
@@ -507,8 +699,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
// defaultDebugAddr is the localhost-bound fallback for the debug endpoint
|
||||
// when DebugEndpointAddress is empty.
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
@@ -661,8 +854,10 @@ func (s *Server) gracefulShutdown() {
|
||||
defer drainCancel()
|
||||
|
||||
s.Logger.Info("draining in-flight connections")
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
if s.https != nil {
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
@@ -809,6 +1004,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu
|
||||
return client.DialContext, nil
|
||||
}
|
||||
|
||||
// initPrivateInbound wires per-account inbound listeners when --private
|
||||
// is set. When the flag is off this is a no-op and the standalone proxy keeps
|
||||
// its byte-for-byte previous behaviour.
|
||||
func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) {
|
||||
if !s.Private {
|
||||
return
|
||||
}
|
||||
s.inbound = newInboundManager(s.Logger, handler, tlsConfig)
|
||||
s.netbird.SetClientLifecycle(s.inbound.onClientReady, s.inbound.onClientStop)
|
||||
s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)")
|
||||
}
|
||||
|
||||
// notifyError reports a resource error back to management so it can be
|
||||
// surfaced to the user (e.g. port bind failure, dialer resolution error).
|
||||
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
|
||||
@@ -942,7 +1149,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
}
|
||||
|
||||
// syncSupported tracks whether management supports SyncMappings.
|
||||
// Starts true; set to false on first Unimplemented error.
|
||||
// Starts true; set to false on the first Unimplemented error so
|
||||
// subsequent retries skip straight to GetMappingUpdate.
|
||||
syncSupported := true
|
||||
initialSyncDone := false
|
||||
|
||||
@@ -992,10 +1200,15 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
|
||||
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
privateCapability := s.Private
|
||||
// Always true: this build enforces ProxyMapping.private via the auth middleware.
|
||||
supportsPrivateService := true
|
||||
return &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
Private: &privateCapability,
|
||||
SupportsPrivateService: &supportsPrivateService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1027,7 +1240,6 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("create sync stream: %w", err)
|
||||
}
|
||||
|
||||
// Send init message.
|
||||
if err := stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Init{
|
||||
Init: &proto.SyncMappingsInit{
|
||||
@@ -1058,6 +1270,10 @@ func isSyncUnimplemented(err error) bool {
|
||||
return ok && st.Code() == codes.Unimplemented
|
||||
}
|
||||
|
||||
// handleSyncMappingsStream consumes batches from a bidirectional SyncMappings
|
||||
// stream, sending an ack after each batch is fully processed. Management waits
|
||||
// for the ack before sending the next batch, providing application-level
|
||||
// back-pressure.
|
||||
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
@@ -1095,39 +1311,10 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := mappingClient.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case err != nil:
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotTracker accumulates service IDs during the initial snapshot and
|
||||
// finalises sync state when the complete flag arrives.
|
||||
// finalises sync state when the complete flag arrives. Used by both
|
||||
// handleMappingStream and handleSyncMappingsStream so metric emission and
|
||||
// reconciliation behave identically on either RPC.
|
||||
type snapshotTracker struct {
|
||||
done *bool
|
||||
connectTime time.Time
|
||||
@@ -1171,6 +1358,37 @@ func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings [
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := mappingClient.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case err != nil:
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
@@ -1192,17 +1410,29 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
s.ensurePeers(ctx, mappings)
|
||||
// mappingJSONMarshal dumps mappings on one line with zero-value fields visible for debug logs.
|
||||
var mappingJSONMarshal = protojson.MarshalOptions{
|
||||
Multiline: false,
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
// The full proto dump carries auth_token and header-auth values; gate on debug.
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"mode": mapping.GetMode(),
|
||||
"port": mapping.GetListenPort(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
if debug {
|
||||
raw, err := mappingJSONMarshal.Marshal(mapping)
|
||||
if err != nil {
|
||||
raw = []byte(fmt.Sprintf("<marshal error: %v>", err))
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"id": mapping.GetId(),
|
||||
"mapping": string(raw),
|
||||
}).Debug("Processing mapping update")
|
||||
}
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
@@ -1228,60 +1458,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
|
||||
}
|
||||
}
|
||||
|
||||
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
|
||||
// CREATED mappings. Peers for different accounts are created concurrently,
|
||||
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
|
||||
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
// Collect one representative mapping per account that needs a new peer.
|
||||
type peerReq struct {
|
||||
accountID types.AccountID
|
||||
svcKey roundtrip.ServiceKey
|
||||
authToken string
|
||||
svcID types.ServiceID
|
||||
}
|
||||
seen := make(map[types.AccountID]struct{})
|
||||
var reqs []peerReq
|
||||
for _, m := range mappings {
|
||||
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
|
||||
continue
|
||||
}
|
||||
accountID := types.AccountID(m.GetAccountId())
|
||||
if _, ok := seen[accountID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[accountID] = struct{}{}
|
||||
if s.netbird.HasClient(accountID) {
|
||||
continue
|
||||
}
|
||||
reqs = append(reqs, peerReq{
|
||||
accountID: accountID,
|
||||
svcKey: s.serviceKeyForMapping(m),
|
||||
authToken: m.GetAuthToken(),
|
||||
svcID: types.ServiceID(m.GetId()),
|
||||
})
|
||||
}
|
||||
|
||||
if len(reqs) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(reqs))
|
||||
for _, r := range reqs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": r.accountID,
|
||||
"service_id": r.svcID,
|
||||
"error": err,
|
||||
}).Warn("failed to pre-create peer for account")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// addMapping registers a service mapping and starts the appropriate relay or routes.
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
@@ -1353,12 +1529,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
|
||||
if s.acme != nil {
|
||||
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
|
||||
}
|
||||
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
||||
httpRoute := nbtcp.Route{
|
||||
Type: nbtcp.RouteHTTP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
})
|
||||
}
|
||||
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
|
||||
if s.inbound != nil {
|
||||
s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
|
||||
}
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
return fmt.Errorf("update mapping for domain %q: %w", d, err)
|
||||
}
|
||||
@@ -1718,7 +1898,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil {
|
||||
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
|
||||
}
|
||||
m := s.protoToMapping(ctx, mapping)
|
||||
@@ -1774,6 +1954,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
|
||||
}
|
||||
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
|
||||
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
|
||||
if s.inbound != nil {
|
||||
s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and delete tracked custom-port entries atomically.
|
||||
@@ -1861,6 +2044,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
if d := opts.GetRequestTimeout(); d != nil {
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
pt.DirectUpstream = opts.GetDirectUpstream()
|
||||
}
|
||||
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
|
||||
@@ -3067,6 +3067,17 @@ components:
|
||||
$ref: '#/components/schemas/AccessRestrictions'
|
||||
meta:
|
||||
$ref: '#/components/schemas/ServiceMeta'
|
||||
private:
|
||||
type: boolean
|
||||
description: When true, the service is NetBird-only — its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http.
|
||||
default: false
|
||||
example: false
|
||||
access_groups:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO).
|
||||
example: ["group-engineering"]
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
@@ -3147,6 +3158,17 @@ components:
|
||||
$ref: '#/components/schemas/ServiceAuthConfig'
|
||||
access_restrictions:
|
||||
$ref: '#/components/schemas/AccessRestrictions'
|
||||
private:
|
||||
type: boolean
|
||||
description: When true, the service is NetBird-only — its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http.
|
||||
default: false
|
||||
example: false
|
||||
access_groups:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO).
|
||||
example: ["group-engineering"]
|
||||
required:
|
||||
- name
|
||||
- domain
|
||||
@@ -3185,6 +3207,15 @@ components:
|
||||
type: string
|
||||
description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m").
|
||||
example: "2m"
|
||||
direct_upstream:
|
||||
type: boolean
|
||||
description: |
|
||||
When true, the proxy dials this target via the host's network stack
|
||||
instead of through its embedded NetBird client. Use for upstreams
|
||||
reachable without WireGuard (public APIs, LAN services, localhost
|
||||
sidecars).
|
||||
default: false
|
||||
example: false
|
||||
ServiceTarget:
|
||||
type: object
|
||||
properties:
|
||||
@@ -3195,7 +3226,7 @@ components:
|
||||
target_type:
|
||||
type: string
|
||||
description: Target type
|
||||
enum: [peer, host, domain, subnet]
|
||||
enum: [peer, host, domain, subnet, cluster]
|
||||
example: "subnet"
|
||||
path:
|
||||
type: string
|
||||
@@ -3439,6 +3470,10 @@ components:
|
||||
type: boolean
|
||||
description: Whether all active proxies in the cluster have CrowdSec configured
|
||||
example: false
|
||||
private:
|
||||
type: boolean
|
||||
description: True when at least one connected proxy in this cluster is running embedded in a netbird client (`netbird proxy`) and serving over a WireGuard tunnel. Lets the dashboard distinguish per-peer / private clusters from centralised ones.
|
||||
example: false
|
||||
required:
|
||||
- id
|
||||
- address
|
||||
@@ -3494,6 +3529,10 @@ components:
|
||||
type: boolean
|
||||
description: Whether the proxy cluster has CrowdSec configured
|
||||
example: false
|
||||
supports_private:
|
||||
type: boolean
|
||||
description: Whether the proxy cluster supports private (NetBird-only) services. True when at least one connected proxy in the cluster runs embedded in a netbird client.
|
||||
example: false
|
||||
required:
|
||||
- id
|
||||
- domain
|
||||
|
||||
@@ -1063,15 +1063,18 @@ func (e ServiceTargetProtocol) Valid() bool {
|
||||
|
||||
// Defines values for ServiceTargetTargetType.
|
||||
const (
|
||||
ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain"
|
||||
ServiceTargetTargetTypeHost ServiceTargetTargetType = "host"
|
||||
ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer"
|
||||
ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet"
|
||||
ServiceTargetTargetTypeCluster ServiceTargetTargetType = "cluster"
|
||||
ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain"
|
||||
ServiceTargetTargetTypeHost ServiceTargetTargetType = "host"
|
||||
ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer"
|
||||
ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet"
|
||||
)
|
||||
|
||||
// Valid indicates whether the value is a known member of the ServiceTargetTargetType enum.
|
||||
func (e ServiceTargetTargetType) Valid() bool {
|
||||
switch e {
|
||||
case ServiceTargetTargetTypeCluster:
|
||||
return true
|
||||
case ServiceTargetTargetTypeDomain:
|
||||
return true
|
||||
case ServiceTargetTargetTypeHost:
|
||||
@@ -3819,6 +3822,9 @@ type ProxyCluster struct {
|
||||
// Online Whether at least one proxy in the cluster has heartbeated within the active window
|
||||
Online bool `json:"online"`
|
||||
|
||||
// Private True when at least one connected proxy in this cluster is running embedded in a netbird client (`netbird proxy`) and serving over a WireGuard tunnel. Lets the dashboard distinguish per-peer / private clusters from centralised ones.
|
||||
Private *bool `json:"private,omitempty"`
|
||||
|
||||
// RequireSubdomain Whether services on this cluster must include a subdomain label
|
||||
RequireSubdomain *bool `json:"require_subdomain,omitempty"`
|
||||
|
||||
@@ -3896,6 +3902,9 @@ type ReverseProxyDomain struct {
|
||||
// SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports
|
||||
SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"`
|
||||
|
||||
// SupportsPrivate Whether the proxy cluster supports private (NetBird-only) services. True when at least one connected proxy in the cluster runs embedded in a netbird client.
|
||||
SupportsPrivate *bool `json:"supports_private,omitempty"`
|
||||
|
||||
// TargetCluster The proxy cluster this domain is validated against (only for custom domains)
|
||||
TargetCluster *string `json:"target_cluster,omitempty"`
|
||||
|
||||
@@ -4085,6 +4094,9 @@ type SentinelOneMatchAttributesNetworkStatus string
|
||||
|
||||
// Service defines model for Service.
|
||||
type Service struct {
|
||||
// AccessGroups NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO).
|
||||
AccessGroups *[]string `json:"access_groups,omitempty"`
|
||||
|
||||
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
|
||||
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
|
||||
Auth ServiceAuthConfig `json:"auth"`
|
||||
@@ -4114,6 +4126,9 @@ type Service struct {
|
||||
// PortAutoAssigned Whether the listen port was auto-assigned
|
||||
PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"`
|
||||
|
||||
// Private When true, the service is NetBird-only — its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http.
|
||||
Private *bool `json:"private,omitempty"`
|
||||
|
||||
// ProxyCluster The proxy cluster handling this service (derived from domain)
|
||||
ProxyCluster *string `json:"proxy_cluster,omitempty"`
|
||||
|
||||
@@ -4156,6 +4171,9 @@ type ServiceMetaStatus string
|
||||
|
||||
// ServiceRequest defines model for ServiceRequest.
|
||||
type ServiceRequest struct {
|
||||
// AccessGroups NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO).
|
||||
AccessGroups *[]string `json:"access_groups,omitempty"`
|
||||
|
||||
// AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services.
|
||||
AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"`
|
||||
Auth *ServiceAuthConfig `json:"auth,omitempty"`
|
||||
@@ -4178,6 +4196,9 @@ type ServiceRequest struct {
|
||||
// PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
|
||||
PassHostHeader *bool `json:"pass_host_header,omitempty"`
|
||||
|
||||
// Private When true, the service is NetBird-only — its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http.
|
||||
Private *bool `json:"private,omitempty"`
|
||||
|
||||
// RewriteRedirects When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
|
||||
RewriteRedirects *bool `json:"rewrite_redirects,omitempty"`
|
||||
|
||||
@@ -4224,6 +4245,12 @@ type ServiceTargetOptions struct {
|
||||
// CustomHeaders Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
|
||||
CustomHeaders *map[string]string `json:"custom_headers,omitempty"`
|
||||
|
||||
// DirectUpstream When true, the proxy dials this target via the host's network stack
|
||||
// instead of through its embedded NetBird client. Use for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars).
|
||||
DirectUpstream *bool `json:"direct_upstream,omitempty"`
|
||||
|
||||
// PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
|
||||
PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"`
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,15 @@ service ProxyService {
|
||||
// ValidateSession validates a session token and checks user access permissions.
|
||||
// Called by the proxy after receiving a session token from OIDC callback.
|
||||
rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse);
|
||||
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the resolved user's access against the service's access_groups.
|
||||
// Acts as a fast-path equivalent of OIDC for requests originating on the
|
||||
// netbird mesh: when the source IP maps to a known peer in the calling
|
||||
// account and that peer is in the service's access_groups, the proxy can
|
||||
// issue a session cookie without redirecting through the OIDC flow.
|
||||
// Mirrors ValidateSession's response shape.
|
||||
rpc ValidateTunnelPeer(ValidateTunnelPeerRequest) returns (ValidateTunnelPeerResponse);
|
||||
}
|
||||
|
||||
// ProxyCapabilities describes what a proxy can handle.
|
||||
@@ -45,6 +54,13 @@ message ProxyCapabilities {
|
||||
optional bool require_subdomain = 2;
|
||||
// Whether the proxy has CrowdSec configured and can enforce IP reputation checks.
|
||||
optional bool supports_crowdsec = 3;
|
||||
// Whether the proxy is running embedded in the netbird client and serving
|
||||
// exclusively over the WireGuard tunnel (i.e. `netbird proxy` rather than
|
||||
// the standalone netbird-proxy binary). Surfaces upstream so dashboards can
|
||||
// distinguish per-peer / private clusters from centralised ones.
|
||||
optional bool private = 4;
|
||||
// Whether the proxy enforces ProxyMapping.private (fails closed on ValidateTunnelPeer failure). Management MUST NOT stream private mappings to proxies that don't claim this.
|
||||
optional bool supports_private_service = 5;
|
||||
}
|
||||
|
||||
// GetMappingUpdateRequest is sent to initialise a mapping stream.
|
||||
@@ -86,6 +102,11 @@ message PathTargetOptions {
|
||||
bool proxy_protocol = 5;
|
||||
// Idle timeout before a UDP session is reaped.
|
||||
google.protobuf.Duration session_idle_timeout = 6;
|
||||
// When true, the proxy dials this target via the host's network stack
|
||||
// instead of through the embedded NetBird client. Useful for upstreams
|
||||
// reachable without WireGuard (public APIs, LAN services, localhost
|
||||
// sidecars). Defaults to false — embedded client is the standard path.
|
||||
bool direct_upstream = 7;
|
||||
}
|
||||
|
||||
message PathMapping {
|
||||
@@ -138,6 +159,8 @@ message ProxyMapping {
|
||||
// For L4/TLS: the port the proxy listens on.
|
||||
int32 listen_port = 11;
|
||||
AccessRestrictions access_restrictions = 12;
|
||||
// NetBird-only: the proxy MUST call ValidateTunnelPeer and fail closed; operator auth schemes are bypassed.
|
||||
bool private = 13;
|
||||
}
|
||||
|
||||
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
|
||||
@@ -213,6 +236,25 @@ message SendStatusUpdateRequest {
|
||||
ProxyStatus status = 3;
|
||||
bool certificate_issued = 4;
|
||||
optional string error_message = 5;
|
||||
// Per-account inbound listener state for the account that owns
|
||||
// service_id. Populated only when --private-inbound is enabled and the
|
||||
// embedded client for the account is up. Field numbers >=50 reserved
|
||||
// for observability extensions.
|
||||
optional ProxyInboundListener inbound_listener = 50;
|
||||
}
|
||||
|
||||
// ProxyInboundListener describes a per-account inbound listener that the
|
||||
// proxy has bound on the embedded netstack of the account's WireGuard
|
||||
// client. Surfaced so dashboards can render "this account is reachable
|
||||
// at <tunnel_ip>:<https_port> on this proxy".
|
||||
message ProxyInboundListener {
|
||||
// Tunnel IP the embedded netstack listens on. Same address other peers
|
||||
// in the account see for the proxy peer.
|
||||
string tunnel_ip = 1;
|
||||
// TLS port served on tunnel_ip (auto-detected, default 443).
|
||||
uint32 https_port = 2;
|
||||
// Plain-HTTP port served on tunnel_ip (auto-detected, default 80).
|
||||
uint32 http_port = 3;
|
||||
}
|
||||
|
||||
// SendStatusUpdateResponse is intentionally empty to allow for future expansion
|
||||
@@ -254,6 +296,52 @@ message ValidateSessionResponse {
|
||||
string user_id = 2;
|
||||
string user_email = 3;
|
||||
string denied_reason = 4;
|
||||
// peer_group_ids carries the calling user's group memberships so the
|
||||
// proxy can authorise policy-aware middlewares without an additional
|
||||
// management round-trip.
|
||||
repeated string peer_group_ids = 5;
|
||||
// peer_group_names carries the human-readable display names for the
|
||||
// ids in peer_group_ids, ordered identically (positional pairing).
|
||||
// Stamped onto upstream requests as X-NetBird-Groups so downstream
|
||||
// services can read names rather than opaque ids.
|
||||
repeated string peer_group_names = 6;
|
||||
}
|
||||
|
||||
// ValidateTunnelPeerRequest carries the inbound peer's tunnel IP and the
|
||||
// service domain whose group requirements should gate access. The calling
|
||||
// account is inferred from the proxy's gRPC metadata (ProxyToken).
|
||||
message ValidateTunnelPeerRequest {
|
||||
string tunnel_ip = 1;
|
||||
string domain = 2;
|
||||
}
|
||||
|
||||
// ValidateTunnelPeerResponse mirrors ValidateSessionResponse plus a freshly
|
||||
// minted session_token: when valid is true, the proxy installs the token as
|
||||
// a session cookie so subsequent requests skip the management round-trip,
|
||||
// matching the OIDC flow's UX. denied_reason values:
|
||||
// "peer_not_found" — no peer with that tunnel IP in the calling account
|
||||
// "no_user" — peer exists but is not bound to a user
|
||||
// "service_not_found"
|
||||
// "account_mismatch"
|
||||
// "not_in_group" — peer resolved but not in service.access_groups
|
||||
message ValidateTunnelPeerResponse {
|
||||
bool valid = 1;
|
||||
string user_id = 2;
|
||||
string user_email = 3;
|
||||
string denied_reason = 4;
|
||||
// session_token is set only when valid is true. Same shape as the JWT
|
||||
// the OIDC flow produces — proxy installs it via setSessionCookie so the
|
||||
// tunnel fast-path is indistinguishable from OIDC for subsequent requests.
|
||||
string session_token = 5;
|
||||
// peer_group_ids carries the resolved peer's user group memberships so
|
||||
// the proxy can authorise policy-aware middlewares without an additional
|
||||
// management round-trip.
|
||||
repeated string peer_group_ids = 6;
|
||||
// peer_group_names carries the human-readable display names for the
|
||||
// ids in peer_group_ids, ordered identically (positional pairing).
|
||||
// Stamped onto upstream requests as X-NetBird-Groups so downstream
|
||||
// services can read names rather than opaque ids.
|
||||
repeated string peer_group_names = 7;
|
||||
}
|
||||
|
||||
// SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings
|
||||
@@ -287,3 +375,4 @@ message SyncMappingsResponse {
|
||||
// initial_sync_complete is set on the last message of the initial snapshot.
|
||||
bool initial_sync_complete = 2;
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,14 @@ type ProxyServiceClient interface {
|
||||
// ValidateSession validates a session token and checks user access permissions.
|
||||
// Called by the proxy after receiving a session token from OIDC callback.
|
||||
ValidateSession(ctx context.Context, in *ValidateSessionRequest, opts ...grpc.CallOption) (*ValidateSessionResponse, error)
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the resolved user's access against the service's access_groups.
|
||||
// Acts as a fast-path equivalent of OIDC for requests originating on the
|
||||
// netbird mesh: when the source IP maps to a known peer in the calling
|
||||
// account and that peer is in the service's access_groups, the proxy can
|
||||
// issue a session cookie without redirecting through the OIDC flow.
|
||||
// Mirrors ValidateSession's response shape.
|
||||
ValidateTunnelPeer(ctx context.Context, in *ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*ValidateTunnelPeerResponse, error)
|
||||
}
|
||||
|
||||
type proxyServiceClient struct {
|
||||
@@ -162,6 +170,15 @@ func (c *proxyServiceClient) ValidateSession(ctx context.Context, in *ValidateSe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *proxyServiceClient) ValidateTunnelPeer(ctx context.Context, in *ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*ValidateTunnelPeerResponse, error) {
|
||||
out := new(ValidateTunnelPeerResponse)
|
||||
err := c.cc.Invoke(ctx, "/management.ProxyService/ValidateTunnelPeer", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ProxyServiceServer is the server API for ProxyService service.
|
||||
// All implementations must embed UnimplementedProxyServiceServer
|
||||
// for forward compatibility
|
||||
@@ -183,6 +200,14 @@ type ProxyServiceServer interface {
|
||||
// ValidateSession validates a session token and checks user access permissions.
|
||||
// Called by the proxy after receiving a session token from OIDC callback.
|
||||
ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error)
|
||||
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
|
||||
// checks the resolved user's access against the service's access_groups.
|
||||
// Acts as a fast-path equivalent of OIDC for requests originating on the
|
||||
// netbird mesh: when the source IP maps to a known peer in the calling
|
||||
// account and that peer is in the service's access_groups, the proxy can
|
||||
// issue a session cookie without redirecting through the OIDC flow.
|
||||
// Mirrors ValidateSession's response shape.
|
||||
ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error)
|
||||
mustEmbedUnimplementedProxyServiceServer()
|
||||
}
|
||||
|
||||
@@ -214,6 +239,9 @@ func (UnimplementedProxyServiceServer) GetOIDCURL(context.Context, *GetOIDCURLRe
|
||||
func (UnimplementedProxyServiceServer) ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ValidateSession not implemented")
|
||||
}
|
||||
func (UnimplementedProxyServiceServer) ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ValidateTunnelPeer not implemented")
|
||||
}
|
||||
func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {}
|
||||
|
||||
// UnsafeProxyServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -382,6 +410,24 @@ func _ProxyService_ValidateSession_Handler(srv interface{}, ctx context.Context,
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ProxyService_ValidateTunnelPeer_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ValidateTunnelPeerRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ProxyServiceServer).ValidateTunnelPeer(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ProxyService/ValidateTunnelPeer",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ProxyServiceServer).ValidateTunnelPeer(ctx, req.(*ValidateTunnelPeerRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -413,6 +459,10 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "ValidateSession",
|
||||
Handler: _ProxyService_ValidateSession_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ValidateTunnelPeer",
|
||||
Handler: _ProxyService_ValidateTunnelPeer_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user