Compare commits

...

2 Commits

Author SHA1 Message Date
pascal
dc2d337071 respect requires approval when adding peer 2026-06-17 11:06:00 +02:00
pascal
0f78a73928 remove nmap calc from login 2026-06-16 17:07:59 +02:00
20 changed files with 318 additions and 158 deletions

View File

@@ -585,66 +585,66 @@ func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
return emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
return networkMap, postureChecks, dnsFwdPort, nil
}
// GetDNSDomain returns the configured dnsDomain

View File

@@ -23,7 +23,7 @@ type Controller interface {
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -127,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].([]*posture.Checks)
ret2, _ := ret[2].(int64)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
}
// OnPeerConnected mocks base method.

View File

@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
},
}
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
if err != nil {
return fmt.Errorf("failed to create proxy peer: %w", err)
}

View File

@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: peerMeta,
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -70,7 +70,7 @@ type Manager interface {
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -109,7 +109,7 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager

View File

@@ -80,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
}
// AddPeer mocks base method.
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// AddPeer indicates an expected call of AddPeer.
@@ -1289,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
}
// LoginPeer mocks base method.
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// LoginPeer indicates an expected call of LoginPeer.

View File

@@ -84,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -1092,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1156,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1504,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}, false)
@@ -1826,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1882,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1927,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
@@ -2017,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
@@ -2080,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -3276,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
Status: &nbpeer.PeerStatus{
@@ -3444,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3513,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3908,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)

View File

@@ -1663,7 +1663,7 @@ func addPeerToAccount(t *testing.T, manager *DefaultAccountManager, _, setupKeyK
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKeyKey, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKeyKey, "", &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: key.PublicKey().String()},
}, false)

View File

@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return nil, err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -55,7 +55,7 @@ func TestGroupIPv6Assignment(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
}, false)

View File

@@ -479,7 +479,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
return
}
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
peer, _, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -728,7 +728,7 @@ func Test_LoginPerformance(t *testing.T) {
}
login := func() error {
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return err
@@ -746,7 +746,7 @@ func Test_LoginPerformance(t *testing.T) {
go func(peerLogin types.PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return

View File

@@ -45,7 +45,7 @@ type MockAccountManager struct {
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
@@ -98,7 +98,7 @@ type MockAccountManager struct {
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@@ -424,11 +424,11 @@ func (am *MockAccountManager) AddPeer(
userId string,
peer *nbpeer.Peer,
temporary bool,
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
@@ -862,11 +862,11 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
}
// LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(ctx, login)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
}
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface

View File

@@ -896,11 +896,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -718,10 +718,10 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
}
upperKey := strings.ToUpper(setupKey)
@@ -737,7 +737,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
// The connecting peer should be able to recover with a retry.
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
return nil, nil, nil, false, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
@@ -748,7 +748,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
accountID = peerAddConfig.AccountID
ephemeral := peerAddConfig.Ephemeral
@@ -763,7 +763,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
return nil, nil, nil, false, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
}
registrationTime := time.Now().UTC()
@@ -789,7 +789,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get account settings: %w", err)
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
@@ -807,30 +807,30 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed getting network: %w", err)
}
maxAttempts := 10
for attempt := 1; attempt <= maxAttempts; attempt++ {
netPrefix, err := netip.ParsePrefix(network.Net.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err)
return nil, nil, nil, false, fmt.Errorf("parse network prefix: %w", err)
}
freeIP, err := types.AllocateRandomPeerIP(netPrefix)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free IP: %w", err)
}
var freeLabel string
if ephemeral || attempt > 1 {
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
}
} else {
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
}
}
newPeer.DNSLabel = freeLabel
@@ -852,11 +852,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
if allocate {
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
return nil, nil, nil, false, fmt.Errorf("parse IPv6 prefix: %w", err)
}
freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix)
if err != nil {
return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err)
return nil, nil, nil, false, fmt.Errorf("allocate peer IPv6: %w", err)
}
newPeer.IPv6 = freeIPv6
}
@@ -929,10 +929,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
continue
}
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to add peer to database: %w", err)
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
return nil, nil, nil, false, fmt.Errorf("new peer is nil")
}
opEvent.TargetID = newPeer.ID
@@ -940,7 +940,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
if !addedByUser {
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
}
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
requiresApproval := newPeer.Status != nil && newPeer.Status.RequiresApproval
if requiresApproval {
opEvent.Meta["pending_approval"] = true
}
@@ -948,18 +949,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, newPeer, !requiresApproval)
if err != nil {
return p, nmap, pc, err
return nil, nil, nil, false, err
}
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := affectedPeerIDsFromNetworkMap(nmap, newPeer.ID)
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
return p, nmap, pc, nil
return newPeer, network, postureChecks, enableSSH, nil
}
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
@@ -1041,7 +1042,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
resPeer, nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
@@ -1054,7 +1055,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
}
return resPeer, nmap, resPostureChecks, dnsFwdPort, nil
return peer, nmap, resPostureChecks, dnsFwdPort, nil
}
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
@@ -1085,7 +1086,7 @@ func (am *DefaultAccountManager) markConnectedAffectedPeers(ctx context.Context,
return affectedPeerIDsFromNetworkMap(nmap, peerID)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it.
@@ -1101,12 +1102,12 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
}
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
return nil, nil, nil, false, status.Errorf(status.Internal, "failed while logging in peer")
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return am.handlePeerLoginNotFound(ctx, login, err)
@@ -1118,20 +1119,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if login.UserID == "" {
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
}
var peer *nbpeer.Peer
var updateRemotePeers bool
var isPeerUpdated bool
var ipv6CapabilityChanged bool
var postureChecks []*posture.Checks
var shouldStorePeer bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1140,9 +1138,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return err
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
@@ -1156,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
updateRemotePeers = true
}
}
@@ -1165,23 +1159,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if isPeerUpdated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
updateRemotePeers = true
}
if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 {
@@ -1197,28 +1177,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, peer, !isRequiresApproval)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
if isStatusChanged || shouldStorePeer {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, isRequiresApproval, isPeerUpdated, len(postureChecks) > 0)
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
return nil, nil, nil, false, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return p, nmap, pc, nil
return peer, network, postureChecks, enableSSH, nil
}
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
@@ -1294,6 +1274,50 @@ func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubK
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
}
// getPeerLoginInfo computes the login/register response data (network, posture
// checks, SSH) from the store without building the peer's full network map.
func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer, isValid bool) (*types.Network, []*posture.Checks, bool, error) {
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, false, fmt.Errorf("get account network: %w", err)
}
if !isValid {
return network, nil, false, nil
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
if err != nil {
return nil, nil, false, err
}
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)

View File

@@ -205,7 +205,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
@@ -219,7 +219,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -278,7 +278,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
@@ -292,7 +292,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err)
return
}
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -454,7 +454,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
@@ -468,7 +468,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -526,7 +526,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -542,7 +542,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// the second peer added with a setup key
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -698,7 +698,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
@@ -707,7 +707,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)
@@ -1332,7 +1332,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
},
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false)
addedPeer, _, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false)
require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
@@ -1465,7 +1465,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false)
addedPeer, _, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false)
if tc.expectAddPeerError {
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
@@ -1577,7 +1577,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false)
require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
@@ -1723,7 +1723,7 @@ func Test_LoginPeer(t *testing.T) {
if sk.AllowExtraDNSLabels {
currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels
}
_, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false)
require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey)
loginInput := types.PeerLogin{
@@ -1739,12 +1739,12 @@ func Test_LoginPeer(t *testing.T) {
loginInput.ExtraDNSLabels = tc.extraDNSLabels
}
loggedinPeer, networkMap, postureChecks, loginErr := am.LoginPeer(context.Background(), loginInput)
loggedinPeer, network, postureChecks, _, loginErr := am.LoginPeer(context.Background(), loginInput)
if tc.expectLoginError {
require.Error(t, loginErr, "Expected an error during LoginPeer with setup key: %s", tc.setupKey)
assert.Contains(t, loginErr.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
assert.Nil(t, loggedinPeer, "LoggedinPeer should be nil on error")
assert.Nil(t, networkMap, "NetworkMap should be nil on error")
assert.Nil(t, network, "Network should be nil on error")
assert.Nil(t, postureChecks, "PostureChecks should be empty or nil on error")
return
}
@@ -1757,7 +1757,7 @@ func Test_LoginPeer(t *testing.T) {
} else {
assert.Equal(t, currentPeer.ExtraDNSLabels, loggedinPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch on loggedinPeer")
}
assert.NotNil(t, networkMap, "networkMap should not be nil on success")
assert.NotNil(t, network, "network should not be nil on success")
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
@@ -1863,7 +1863,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
peer4, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1986,7 +1986,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
peer4, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
@@ -2053,7 +2053,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
peer5, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
@@ -2108,7 +2108,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{
peer6, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{
Key: expectedPeerKey,
LoginExpirationEnabled: true,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
@@ -2286,7 +2286,7 @@ func Test_AddPeer(t *testing.T) {
<-start
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false)
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
@@ -2366,7 +2366,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false)
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
}
@@ -2401,7 +2401,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false)
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false)
require.NoError(t, err, "Regular user should be able to add peers")
}
@@ -2444,7 +2444,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false)
existingPeer, _, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false)
require.NoError(t, err)
// Now set the user back to pending approval after peer was created
@@ -2463,7 +2463,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
_, _, _, _, err = manager.LoginPeer(context.Background(), login)
require.Error(t, err)
e, ok := status.FromError(err)
require.True(t, ok, "error is not a gRPC status error")
@@ -2500,7 +2500,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false)
existingPeer, _, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false)
require.NoError(t, err)
// Try to login with regular user
@@ -2513,7 +2513,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
_, _, _, _, err = manager.LoginPeer(context.Background(), login)
require.NoError(t, err, "Regular user should be able to login peers")
}
@@ -2837,7 +2837,7 @@ func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
// Add first peer with hostname that produces DNS label "netbird1"
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
}, false)
@@ -2847,7 +2847,7 @@ func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
// Add second peer with a different hostname
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
}, false)
@@ -2871,7 +2871,7 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
key1, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
}, false)
@@ -2881,7 +2881,7 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
key2, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
}, false)

View File

@@ -1156,6 +1156,47 @@ func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
// PeerSSHEnabledFromPolicies is the network-map-free equivalent of the sshEnabled
// determination in GetPeerConnectionResources / CalculateNetworkMapFromComponents.
func PeerSSHEnabledFromPolicies(policies []*Policy, peerID string, peerGroupIDs map[string]struct{}, peerSSHEnabled bool) bool {
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
isSSHRule := rule.Protocol == PolicyRuleProtocolNetbirdSSH ||
(policyRuleImpliesLegacySSH(rule) && peerSSHEnabled)
if !isSSHRule {
continue
}
if ruleHasDestination(rule, peerID, peerGroupIDs) {
return true
}
}
}
return false
}
func ruleHasDestination(rule *PolicyRule, peerID string, peerGroupIDs map[string]struct{}) bool {
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if _, ok := peerGroupIDs[groupID]; ok {
return true
}
}
return false
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {

View File

@@ -1233,3 +1233,97 @@ func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) {
assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000")
assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001")
}
func peerGroupIDSet(account *types.Account, peerID string) map[string]struct{} {
return account.GetPeerGroups(peerID)
}
func assertSSHEquivalence(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) {
t.Helper()
nm := componentsNetworkMap(account, peerID, validatedPeers)
require.NotNil(t, nm)
got := types.PeerSSHEnabledFromPolicies(account.Policies, peerID, peerGroupIDSet(account, peerID), account.Peers[peerID].SSHEnabled)
assert.Equalf(t, nm.EnableSSH, got, "PeerSSHEnabledFromPolicies mismatch for %s", peerID)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_NetbirdSSHProtocol(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account",
Rules: []*types.PolicyRule{{
ID: "rule-ssh", Name: "Allow SSH", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Bidirectional: false,
Sources: []string{"group-0"}, Destinations: []string{"group-1"},
AuthorizedGroups: map[string][]string{"ssh-users": {"root"}},
}},
})
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
assertSSHEquivalence(t, account, "peer-0", validatedPeers)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_NoSSHPolicy(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
assertSSHEquivalence(t, account, "peer-0", validatedPeers)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_LegacyImpliedSSH(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
account.Peers["peer-10"].SSHEnabled = true
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
assertSSHEquivalence(t, account, "peer-11", validatedPeers)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_PeerAsDestinationResource(t *testing.T) {
account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh-res", Name: "SSH to peer", Enabled: true, AccountID: "test-account",
Rules: []*types.PolicyRule{{
ID: "rule-ssh-res", Name: "SSH to peer-5", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Sources: []string{"group-0"},
DestinationResource: types.Resource{ID: "peer-5", Type: types.ResourceTypePeer},
}},
})
assertSSHEquivalence(t, account, "peer-5", validatedPeers)
assertSSHEquivalence(t, account, "peer-6", validatedPeers)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_DisabledSSHPolicy(t *testing.T) {
account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh-off", Name: "SSH disabled", Enabled: false, AccountID: "test-account",
Rules: []*types.PolicyRule{{
ID: "rule-ssh-off", Name: "Allow SSH", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Sources: []string{"group-0"}, Destinations: []string{"group-1"},
}},
})
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
}
func TestPeerSSHEnabledFromPolicies_MatchesMap_Sweep(t *testing.T) {
account, validatedPeers := scalableTestAccount(60, 6)
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh-sweep", Name: "SSH sweep", Enabled: true, AccountID: "test-account",
Rules: []*types.PolicyRule{{
ID: "rule-ssh-sweep", Name: "Allow SSH", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Sources: []string{"group-0"}, Destinations: []string{"group-2"},
}},
})
for peerID := range account.Peers {
account.Peers[peerID].SSHEnabled = len(peerID)%2 == 0
}
for peerID := range account.Peers {
if _, ok := validatedPeers[peerID]; !ok {
continue
}
assertSSHEquivalence(t, account, peerID, validatedPeers)
}
}

View File

@@ -1565,7 +1565,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
expectedPeerKey := key.PublicKey().String()
peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
peer4, _, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)