mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
6 Commits
v0.23.7
...
fix/events
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56d82a99e1 | ||
|
|
d25f543913 | ||
|
|
672b4c0401 | ||
|
|
1219006a6e | ||
|
|
4791e41004 | ||
|
|
9131069d12 |
21
.github/workflows/test-infrastructure-files.yml
vendored
21
.github/workflows/test-infrastructure-files.yml
vendored
@@ -112,6 +112,27 @@ jobs:
|
||||
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: Build management binary
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
run: CGO_ENABLED=0 go build -o netbird-signal main.go
|
||||
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
run: |
|
||||
|
||||
@@ -170,6 +170,10 @@ if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
|
||||
export NETBIRD_AUTH_PKCE_AUDIENCE=
|
||||
fi
|
||||
|
||||
# Read the encryption key
|
||||
encKey=$(grep DataStoreEncryptionKey management.json | awk -F'"' '{$0=$4}1')
|
||||
export NETBIRD_DATASTORE_ENC_KEY=$encKey
|
||||
|
||||
env | grep NETBIRD
|
||||
|
||||
envsubst <docker-compose.yml.tmpl >docker-compose.yml
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
"Password": null
|
||||
},
|
||||
"Datadir": "",
|
||||
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
|
||||
"HttpConfig": {
|
||||
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
|
||||
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
|
||||
@@ -301,15 +301,23 @@ var (
|
||||
func initEventStore(dataDir string, key string) (activity.Store, string, error) {
|
||||
var err error
|
||||
if key == "" {
|
||||
log.Debugf("generate new activity store encryption key")
|
||||
key, err = sqlite.GenerateKey()
|
||||
log.Debugf("restore or generate new activity store encryption key")
|
||||
key, err = sqlite.RestoreKey(dataDir)
|
||||
if err == nil {
|
||||
goto CreateStore
|
||||
} else {
|
||||
log.Debugf("failed to restore encryption key for activity store: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("generate new encryption key for activity store")
|
||||
key, err = sqlite.GenerateKey(dataDir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
CreateStore:
|
||||
store, err := sqlite.NewSQLiteStore(dataDir, key)
|
||||
return store, key, err
|
||||
|
||||
}
|
||||
|
||||
func notifyStop(msg string) {
|
||||
|
||||
@@ -62,12 +62,9 @@ type AccountManager interface {
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||
MarkPeerConnected(peerKey string, connected bool) error
|
||||
DeletePeer(accountID, peerID, userID string) error
|
||||
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
|
||||
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
GetPeerNetwork(peerID string) (*Network, error)
|
||||
@@ -84,7 +81,6 @@ type AccountManager interface {
|
||||
ListGroups(accountId string) ([]*Group, error)
|
||||
GroupAddPeer(accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(accountId, groupID, peerID string) error
|
||||
GroupListPeers(accountId, groupID string) ([]*Peer, error)
|
||||
GetPolicy(accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(accountID, userID string, policy *Policy) error
|
||||
DeletePolicy(accountID, policyID, userID string) error
|
||||
@@ -303,17 +299,6 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetPeerByIP returns peer by it's IP if exists under account or nil otherwise
|
||||
func (a *Account) GetPeerByIP(peerIP string) *Peer {
|
||||
for _, peer := range a.Peers {
|
||||
if peerIP == peer.IP.String() {
|
||||
return peer
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGroup returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetGroup(groupID string) *Group {
|
||||
return a.Groups[groupID]
|
||||
@@ -1491,9 +1476,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
log.Errorf("failed to save account: %v", err)
|
||||
} else {
|
||||
if err := am.updateAccountPeers(account); err != nil {
|
||||
log.Errorf("failed updating account peers while updating user %s", account.Id)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
@@ -1604,26 +1587,6 @@ func isDomainValid(domain string) bool {
|
||||
return re.Match([]byte(domain))
|
||||
}
|
||||
|
||||
// AccountExists checks whether account exists (returns true) or not (returns false)
|
||||
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
var res bool
|
||||
_, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
res = false
|
||||
return &res, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
res = true
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
return am.dnsDomain
|
||||
|
||||
@@ -706,30 +706,6 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func TestAccountManager_AccountExists(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
expectedId := "test_account"
|
||||
userId := "account_creator"
|
||||
_, err = createAccount(manager, expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
exists, err := manager.AccountExists(expectedId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !*exists {
|
||||
t.Errorf("expected account to exist after creation, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,12 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const (
|
||||
backupFile = ".datastore.key"
|
||||
)
|
||||
|
||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||
@@ -15,16 +21,40 @@ type FieldEncrypt struct {
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
func RestoreKey(dataDir string) (string, error) {
|
||||
fName := filepath.Join(dataDir, backupFile)
|
||||
data, err := os.ReadFile(fName)
|
||||
return string(data), err
|
||||
}
|
||||
|
||||
func GenerateKey(dataDir string) (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
readableKey := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
err = saveKey(dataDir, readableKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return readableKey, nil
|
||||
}
|
||||
|
||||
func saveKey(dataDir, key string) error {
|
||||
f, err := os.Create(filepath.Join(dataDir, backupFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = f.WriteString(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
||||
binKey, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,11 +2,20 @@ package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestRestoreKey(t *testing.T) {
|
||||
_, err := RestoreKey(t.TempDir())
|
||||
if err != nil {
|
||||
log.Infof("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
key, err := GenerateKey(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
@@ -32,7 +41,7 @@ func TestGenerateKey(t *testing.T) {
|
||||
|
||||
func TestCorruptKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
key, err := GenerateKey(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
@@ -46,7 +55,7 @@ func TestCorruptKey(t *testing.T) {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
newKey, err := GenerateKey()
|
||||
newKey, err := GenerateKey(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,9 @@ const (
|
||||
"VALUES(?, ?, ?, ?, ?, ?)"
|
||||
|
||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
|
||||
|
||||
fallbackName = "unknown"
|
||||
fallbackEmail = "unknown@unknown.com"
|
||||
)
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
@@ -128,6 +131,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
|
||||
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
events := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
for result.Next() {
|
||||
var id int64
|
||||
var operation activity.Activity
|
||||
@@ -156,8 +160,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
if targetUserName != nil {
|
||||
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt username for target id: %s", target)
|
||||
meta["username"] = ""
|
||||
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", target)
|
||||
meta["username"] = fallbackName
|
||||
} else {
|
||||
meta["username"] = name
|
||||
}
|
||||
@@ -166,8 +170,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
if targetEmail != nil {
|
||||
email, err := store.fieldEncrypt.Decrypt(*targetEmail)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt email address for target id: %s", target)
|
||||
meta["email"] = ""
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", target)
|
||||
meta["email"] = fallbackEmail
|
||||
} else {
|
||||
meta["email"] = email
|
||||
}
|
||||
@@ -186,7 +190,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
if initiatorName != nil {
|
||||
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt username of initiator: %s", initiator)
|
||||
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", initiator)
|
||||
event.InitiatorName = fallbackName
|
||||
} else {
|
||||
event.InitiatorName = name
|
||||
}
|
||||
@@ -195,7 +200,8 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
if initiatorEmail != nil {
|
||||
email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", initiator)
|
||||
event.InitiatorEmail = fallbackEmail
|
||||
} else {
|
||||
event.InitiatorEmail = email
|
||||
}
|
||||
@@ -204,6 +210,10 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
key, _ := GenerateKey(dataDir)
|
||||
store, err := NewSQLiteStore(dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -122,7 +122,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
||||
am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
|
||||
@@ -84,10 +84,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
// the following snippet tracks the activity and stores the group events in the event store.
|
||||
// It has to happen after all the operations have been successfully performed.
|
||||
@@ -229,7 +226,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
|
||||
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
@@ -281,7 +280,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
@@ -309,31 +310,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
|
||||
}
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// GroupListPeers returns list of the peers from the group
|
||||
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
peers := make([]*Peer, 0, len(account.Groups))
|
||||
for _, peerID := range group.Peers {
|
||||
p, ok := account.Peers[peerID]
|
||||
if ok {
|
||||
peers = append(peers, p)
|
||||
}
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -53,14 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
||||
Issued: server.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||
for _, peer := range TestPeers {
|
||||
if peer.IP.String() == peerIP {
|
||||
return peer, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("peer not found")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
||||
@@ -125,15 +125,6 @@ func initRoutesTestData() *RoutesHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||
if peerIP != existingPeerID {
|
||||
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
||||
}
|
||||
return &server.Peer{
|
||||
Key: existingPeerKey,
|
||||
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
|
||||
}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
|
||||
@@ -20,12 +20,9 @@ type MockAccountManager struct {
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
||||
@@ -35,7 +32,6 @@ type MockAccountManager struct {
|
||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
||||
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
|
||||
SaveRuleFunc func(accountID, userID string, rule *server.Rule) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
@@ -140,22 +136,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
)
|
||||
}
|
||||
|
||||
// AccountExists mock implementation of AccountExists from server.AccountManager interface
|
||||
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
if am.AccountExistsFunc != nil {
|
||||
return am.AccountExistsFunc(accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerByKey mocks implementation of GetPeerByKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerByKey(peerKey string) (*server.Peer, error) {
|
||||
if am.GetPeerByKeyFunc != nil {
|
||||
return am.GetPeerByKeyFunc(peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByKey is not implemented")
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
@@ -164,14 +144,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
|
||||
if am.GetPeerByIPFunc != nil {
|
||||
return am.GetPeerByIPFunc(accountId, peerIP)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
@@ -296,14 +268,6 @@ func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string)
|
||||
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
|
||||
}
|
||||
|
||||
// GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) {
|
||||
if am.GroupListPeersFunc != nil {
|
||||
return am.GroupListPeersFunc(accountID, groupID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers is not implemented")
|
||||
}
|
||||
|
||||
// GetRule mock implementation of GetRule from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
|
||||
if am.GetRuleFunc != nil {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -74,11 +73,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
@@ -113,11 +108,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
@@ -147,10 +138,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
|
||||
@@ -195,16 +195,6 @@ func (p *PeerStatus) Copy() *PeerStatus {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerByKey looks up peer by its public WireGuard key
|
||||
func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.FindPeerByPubKey(peerPubKey)
|
||||
}
|
||||
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
|
||||
@@ -290,10 +280,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
if oldStatus.LoginExpired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -364,10 +351,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
@@ -433,26 +417,9 @@ func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) er
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
// GetPeerByIP returns peer by its IP
|
||||
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peerIP == peer.IP.String() {
|
||||
return peer, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
@@ -622,10 +589,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain)
|
||||
return newPeer, networkMap, nil
|
||||
@@ -740,10 +704,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
|
||||
}
|
||||
|
||||
if updateRemotePeers {
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
}
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
|
||||
}
|
||||
@@ -817,10 +778,7 @@ func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *Peer, account *A
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
@@ -865,7 +823,9 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
@@ -922,18 +882,12 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) (*Peer, b
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||
peers := account.GetPeers()
|
||||
|
||||
for _, peer := range peers {
|
||||
remotePeerNetworkMap, err := am.GetNetworkMap(peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain)
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -350,7 +350,9 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
||||
}
|
||||
am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePolicy from the store
|
||||
@@ -375,7 +377,9 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
||||
|
||||
am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPolicies from the store
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
@@ -185,11 +184,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
@@ -250,10 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
@@ -283,7 +275,9 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
||||
|
||||
am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
|
||||
@@ -317,7 +317,9 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
||||
}
|
||||
}()
|
||||
|
||||
return newKey, am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return newKey, nil
|
||||
}
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
|
||||
@@ -377,7 +377,9 @@ func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUs
|
||||
meta := map[string]any{"name": tuName, "email": tuEmail}
|
||||
am.storeEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error {
|
||||
@@ -674,9 +676,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := am.updateAccountPeers(account); err != nil {
|
||||
log.Errorf("failed updating account peers while updating user %s", accountID)
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
} else {
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
@@ -870,9 +870,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
if err := am.updateAccountPeers(account); err != nil {
|
||||
return err
|
||||
}
|
||||
am.updateAccountPeers(account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user