mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-17 13:29:57 +00:00
Compare commits
2 Commits
dnsfwd-ext
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc2d337071 | ||
|
|
0f78a73928 |
@@ -585,66 +585,66 @@ func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
|
||||
b.next.Reset(d)
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
return emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
return networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
|
||||
@@ -23,7 +23,7 @@ type Controller interface {
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
@@ -127,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(int64)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
|
||||
ret0, _ := ret[0].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].([]*posture.Checks)
|
||||
ret2, _ := ret[2].(int64)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
|
||||
@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||
}
|
||||
|
||||
@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: peerMeta,
|
||||
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ type Manager interface {
|
||||
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -109,7 +109,7 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
|
||||
@@ -80,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
|
||||
}
|
||||
|
||||
// AddPeer mocks base method.
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// AddPeer indicates an expected call of AddPeer.
|
||||
@@ -1289,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
|
||||
}
|
||||
|
||||
// LoginPeer mocks base method.
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// LoginPeer indicates an expected call of LoginPeer.
|
||||
|
||||
@@ -84,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@@ -1092,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1156,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1504,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
}, false)
|
||||
@@ -1826,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1882,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1927,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
}, false)
|
||||
@@ -2017,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
@@ -2080,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -3276,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
Status: &nbpeer.PeerStatus{
|
||||
@@ -3444,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3513,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3908,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
|
||||
@@ -1663,7 +1663,7 @@ func addPeerToAccount(t *testing.T, manager *DefaultAccountManager, _, setupKeyK
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKeyKey, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKeyKey, "", &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: key.PublicKey().String()},
|
||||
}, false)
|
||||
|
||||
@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGroupIPv6Assignment(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
|
||||
}, false)
|
||||
|
||||
@@ -479,7 +479,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
|
||||
peer, _, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -728,7 +728,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
}
|
||||
|
||||
login := func() error {
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return err
|
||||
@@ -746,7 +746,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
|
||||
go func(peerLogin types.PeerLogin, counterStart *int32) {
|
||||
defer wgPeer.Done()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
|
||||
@@ -45,7 +45,7 @@ type MockAccountManager struct {
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
@@ -98,7 +98,7 @@ type MockAccountManager struct {
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
@@ -424,11 +424,11 @@ func (am *MockAccountManager) AddPeer(
|
||||
userId string,
|
||||
peer *nbpeer.Peer,
|
||||
temporary bool,
|
||||
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
@@ -862,11 +862,11 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// LoginPeer mocks LoginPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if am.LoginPeerFunc != nil {
|
||||
return am.LoginPeerFunc(ctx, login)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
|
||||
|
||||
@@ -896,11 +896,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -718,10 +718,10 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
|
||||
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
@@ -737,7 +737,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
return nil, nil, nil, false, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
@@ -748,7 +748,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
accountID = peerAddConfig.AccountID
|
||||
ephemeral := peerAddConfig.Ephemeral
|
||||
@@ -763,7 +763,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
|
||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
return nil, nil, nil, false, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
@@ -789,7 +789,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
@@ -807,30 +807,30 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
maxAttempts := 10
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
netPrefix, err := netip.ParsePrefix(network.Net.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("parse network prefix: %w", err)
|
||||
}
|
||||
freeIP, err := types.AllocateRandomPeerIP(netPrefix)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
|
||||
var freeLabel string
|
||||
if ephemeral || attempt > 1 {
|
||||
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
} else {
|
||||
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
}
|
||||
newPeer.DNSLabel = freeLabel
|
||||
@@ -852,11 +852,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if allocate {
|
||||
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("parse IPv6 prefix: %w", err)
|
||||
}
|
||||
freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("allocate peer IPv6: %w", err)
|
||||
}
|
||||
newPeer.IPv6 = freeIPv6
|
||||
}
|
||||
@@ -929,10 +929,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
}
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
return nil, nil, nil, false, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
@@ -940,7 +940,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
|
||||
}
|
||||
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
|
||||
requiresApproval := newPeer.Status != nil && newPeer.Status.RequiresApproval
|
||||
if requiresApproval {
|
||||
opEvent.Meta["pending_approval"] = true
|
||||
}
|
||||
|
||||
@@ -948,18 +949,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
}
|
||||
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, newPeer, !requiresApproval)
|
||||
if err != nil {
|
||||
return p, nmap, pc, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
changedPeerIDs := []string{newPeer.ID}
|
||||
affectedPeerIDs := affectedPeerIDsFromNetworkMap(nmap, newPeer.ID)
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
|
||||
return p, nmap, pc, nil
|
||||
return newPeer, network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
@@ -1041,7 +1042,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
resPeer, nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
||||
nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
@@ -1054,7 +1055,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
}
|
||||
|
||||
return resPeer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
@@ -1085,7 +1086,7 @@ func (am *DefaultAccountManager) markConnectedAffectedPeers(ctx context.Context,
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
|
||||
// Try registering it.
|
||||
@@ -1101,12 +1102,12 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
return nil, nil, nil, false, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return am.handlePeerLoginNotFound(ctx, login, err)
|
||||
@@ -1118,20 +1119,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
if login.UserID == "" {
|
||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var updateRemotePeers bool
|
||||
var isPeerUpdated bool
|
||||
var ipv6CapabilityChanged bool
|
||||
var postureChecks []*posture.Checks
|
||||
var shouldStorePeer bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -1140,9 +1138,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return err
|
||||
}
|
||||
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
|
||||
if login.UserID != "" {
|
||||
if peer.UserID != login.UserID {
|
||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||
@@ -1156,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1165,23 +1159,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return err
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if isPeerUpdated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
shouldStorePeer = true
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
|
||||
if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 {
|
||||
@@ -1197,28 +1177,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, peer, !isRequiresApproval)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, isRequiresApproval, isPeerUpdated, len(postureChecks) > 0)
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return p, nmap, pc, nil
|
||||
return peer, network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
|
||||
@@ -1294,6 +1274,50 @@ func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubK
|
||||
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
|
||||
}
|
||||
|
||||
// getPeerLoginInfo computes the login/register response data (network, posture
|
||||
// checks, SSH) from the store without building the peer's full network map.
|
||||
func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer, isValid bool) (*types.Network, []*posture.Checks, bool, error) {
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("get account network: %w", err)
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
|
||||
@@ -205,7 +205,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
@@ -219,7 +219,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -278,7 +278,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
@@ -292,7 +292,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -454,7 +454,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
@@ -468,7 +468,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -526,7 +526,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -542,7 +542,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// the second peer added with a setup key
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -698,7 +698,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
@@ -707,7 +707,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
@@ -1332,7 +1332,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false)
|
||||
addedPeer, _, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
|
||||
|
||||
@@ -1465,7 +1465,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels,
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false)
|
||||
addedPeer, _, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false)
|
||||
|
||||
if tc.expectAddPeerError {
|
||||
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
|
||||
@@ -1577,7 +1577,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
SSHEnabled: false,
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
|
||||
@@ -1723,7 +1723,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
if sk.AllowExtraDNSLabels {
|
||||
currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false)
|
||||
require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey)
|
||||
|
||||
loginInput := types.PeerLogin{
|
||||
@@ -1739,12 +1739,12 @@ func Test_LoginPeer(t *testing.T) {
|
||||
loginInput.ExtraDNSLabels = tc.extraDNSLabels
|
||||
}
|
||||
|
||||
loggedinPeer, networkMap, postureChecks, loginErr := am.LoginPeer(context.Background(), loginInput)
|
||||
loggedinPeer, network, postureChecks, _, loginErr := am.LoginPeer(context.Background(), loginInput)
|
||||
if tc.expectLoginError {
|
||||
require.Error(t, loginErr, "Expected an error during LoginPeer with setup key: %s", tc.setupKey)
|
||||
assert.Contains(t, loginErr.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
|
||||
assert.Nil(t, loggedinPeer, "LoggedinPeer should be nil on error")
|
||||
assert.Nil(t, networkMap, "NetworkMap should be nil on error")
|
||||
assert.Nil(t, network, "Network should be nil on error")
|
||||
assert.Nil(t, postureChecks, "PostureChecks should be empty or nil on error")
|
||||
return
|
||||
}
|
||||
@@ -1757,7 +1757,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
} else {
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, loggedinPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch on loggedinPeer")
|
||||
}
|
||||
assert.NotNil(t, networkMap, "networkMap should not be nil on success")
|
||||
assert.NotNil(t, network, "network should not be nil on success")
|
||||
|
||||
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
|
||||
|
||||
@@ -1863,7 +1863,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
|
||||
peer4, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1986,7 +1986,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
|
||||
peer4, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
@@ -2053,7 +2053,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
|
||||
peer5, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
@@ -2108,7 +2108,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{
|
||||
peer6, _, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
@@ -2286,7 +2286,7 @@ func Test_AddPeer(t *testing.T) {
|
||||
|
||||
<-start
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false)
|
||||
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
@@ -2366,7 +2366,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false)
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
|
||||
}
|
||||
@@ -2401,7 +2401,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false)
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false)
|
||||
require.NoError(t, err, "Regular user should be able to add peers")
|
||||
}
|
||||
|
||||
@@ -2444,7 +2444,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
WtVersion: "0.28.0",
|
||||
},
|
||||
}
|
||||
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false)
|
||||
existingPeer, _, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now set the user back to pending approval after peer was created
|
||||
@@ -2463,7 +2463,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
_, _, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.Error(t, err)
|
||||
e, ok := status.FromError(err)
|
||||
require.True(t, ok, "error is not a gRPC status error")
|
||||
@@ -2500,7 +2500,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
WtVersion: "0.28.0",
|
||||
},
|
||||
}
|
||||
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false)
|
||||
existingPeer, _, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to login with regular user
|
||||
@@ -2513,7 +2513,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
_, _, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.NoError(t, err, "Regular user should be able to login peers")
|
||||
}
|
||||
|
||||
@@ -2837,7 +2837,7 @@ func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
|
||||
// Add first peer with hostname that produces DNS label "netbird1"
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "netbird1.netbird.cloud"},
|
||||
}, false)
|
||||
@@ -2847,7 +2847,7 @@ func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) {
|
||||
// Add second peer with a different hostname
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ip-10-29-5-130"},
|
||||
}, false)
|
||||
@@ -2871,7 +2871,7 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "web-server"},
|
||||
}, false)
|
||||
@@ -2881,7 +2881,7 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
// Add second peer and rename it to a unique FQDN whose first label doesn't collide
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "old-name"},
|
||||
}, false)
|
||||
|
||||
@@ -1156,6 +1156,47 @@ func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
// PeerSSHEnabledFromPolicies is the network-map-free equivalent of the sshEnabled
|
||||
// determination in GetPeerConnectionResources / CalculateNetworkMapFromComponents.
|
||||
func PeerSSHEnabledFromPolicies(policies []*Policy, peerID string, peerGroupIDs map[string]struct{}, peerSSHEnabled bool) bool {
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
isSSHRule := rule.Protocol == PolicyRuleProtocolNetbirdSSH ||
|
||||
(policyRuleImpliesLegacySSH(rule) && peerSSHEnabled)
|
||||
if !isSSHRule {
|
||||
continue
|
||||
}
|
||||
|
||||
if ruleHasDestination(rule, peerID, peerGroupIDs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func ruleHasDestination(rule *PolicyRule, peerID string, peerGroupIDs map[string]struct{}) bool {
|
||||
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
return rule.DestinationResource.ID == peerID
|
||||
}
|
||||
|
||||
for _, groupID := range rule.Destinations {
|
||||
if _, ok := peerGroupIDs[groupID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
|
||||
@@ -1233,3 +1233,97 @@ func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) {
|
||||
assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000")
|
||||
assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001")
|
||||
}
|
||||
|
||||
func peerGroupIDSet(account *types.Account, peerID string) map[string]struct{} {
|
||||
return account.GetPeerGroups(peerID)
|
||||
}
|
||||
|
||||
func assertSSHEquivalence(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) {
|
||||
t.Helper()
|
||||
nm := componentsNetworkMap(account, peerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
got := types.PeerSSHEnabledFromPolicies(account.Policies, peerID, peerGroupIDSet(account, peerID), account.Peers[peerID].SSHEnabled)
|
||||
assert.Equalf(t, nm.EnableSSH, got, "PeerSSHEnabledFromPolicies mismatch for %s", peerID)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_NetbirdSSHProtocol(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account",
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh", Name: "Allow SSH", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Bidirectional: false,
|
||||
Sources: []string{"group-0"}, Destinations: []string{"group-1"},
|
||||
AuthorizedGroups: map[string][]string{"ssh-users": {"root"}},
|
||||
}},
|
||||
})
|
||||
|
||||
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
|
||||
assertSSHEquivalence(t, account, "peer-0", validatedPeers)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_NoSSHPolicy(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
assertSSHEquivalence(t, account, "peer-0", validatedPeers)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_LegacyImpliedSSH(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
account.Peers["peer-10"].SSHEnabled = true
|
||||
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
|
||||
assertSSHEquivalence(t, account, "peer-11", validatedPeers)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_PeerAsDestinationResource(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh-res", Name: "SSH to peer", Enabled: true, AccountID: "test-account",
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh-res", Name: "SSH to peer-5", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Sources: []string{"group-0"},
|
||||
DestinationResource: types.Resource{ID: "peer-5", Type: types.ResourceTypePeer},
|
||||
}},
|
||||
})
|
||||
|
||||
assertSSHEquivalence(t, account, "peer-5", validatedPeers)
|
||||
assertSSHEquivalence(t, account, "peer-6", validatedPeers)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_DisabledSSHPolicy(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh-off", Name: "SSH disabled", Enabled: false, AccountID: "test-account",
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh-off", Name: "Allow SSH", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Sources: []string{"group-0"}, Destinations: []string{"group-1"},
|
||||
}},
|
||||
})
|
||||
assertSSHEquivalence(t, account, "peer-10", validatedPeers)
|
||||
}
|
||||
|
||||
func TestPeerSSHEnabledFromPolicies_MatchesMap_Sweep(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(60, 6)
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh-sweep", Name: "SSH sweep", Enabled: true, AccountID: "test-account",
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh-sweep", Name: "Allow SSH", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Sources: []string{"group-0"}, Destinations: []string{"group-2"},
|
||||
}},
|
||||
})
|
||||
for peerID := range account.Peers {
|
||||
account.Peers[peerID].SSHEnabled = len(peerID)%2 == 0
|
||||
}
|
||||
|
||||
for peerID := range account.Peers {
|
||||
if _, ok := validatedPeers[peerID]; !ok {
|
||||
continue
|
||||
}
|
||||
assertSSHEquivalence(t, account, peerID, validatedPeers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1565,7 +1565,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
|
||||
peer4, _, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
|
||||
Reference in New Issue
Block a user